Skip to content

Commit

Permalink
community[patch]: Databricks SerDe uses cloudpickle instead of pickle (
Browse files Browse the repository at this point in the history
…langchain-ai#18607)

- **Description:** Databricks SerDe uses cloudpickle instead of pickle
when serializing a user-defined function transform_input_fn since pickle
does not support functions defined in `__main__`, and cloudpickle
supports this.
- **Dependencies:** cloudpickle>=2.0.0

Added a unit test.
  • Loading branch information
liangz1 authored and Dave Bechberger committed Mar 29, 2024
1 parent 36d2a77 commit 847f24f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
15 changes: 12 additions & 3 deletions libs/community/langchain_community/llms/databricks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import pickle
import re
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -225,7 +224,12 @@ def _is_hex_string(data: str) -> bool:
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
"""Loads a pickled function from a hexadecimal string."""
try:
return pickle.loads(bytes.fromhex(data))
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

try:
return cloudpickle.loads(bytes.fromhex(data))
except Exception as e:
raise ValueError(
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
Expand All @@ -235,7 +239,12 @@ def _load_pickled_fn_from_hex_string(data: str) -> Callable:
def _pickle_fn_to_hex_string(fn: Callable) -> str:
"""Pickles a function and returns the hexadecimal string."""
try:
return pickle.dumps(fn).hex()
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

try:
return cloudpickle.dumps(fn).hex()
except Exception as e:
raise ValueError(f"Failed to pickle the function: {e}")

Expand Down
8 changes: 4 additions & 4 deletions libs/community/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion libs/community/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ hologres-vector = {version = "^0.0.6", optional = true}
praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
cloudpickle = {version = ">=2.0.0", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
datasets = {version = "^2.15.0", optional = true}
tree-sitter = {version = "^0.20.2", optional = true}
Expand Down Expand Up @@ -249,6 +250,7 @@ extended_testing = [
"hologres-vector",
"praw",
"databricks-vectorsearch",
"cloudpickle",
"dgml-utils",
"cohere",
"tree-sitter",
Expand All @@ -260,7 +262,8 @@ extended_testing = [
"elasticsearch",
"hdbcli",
"oci",
"rdflib"
"rdflib",
"cloudpickle",
]

[tool.ruff]
Expand Down
16 changes: 13 additions & 3 deletions libs/community/tests/unit_tests/llms/test_databricks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""test Databricks LLM"""
import pickle
from typing import Any, Dict

import pytest
from pytest import MonkeyPatch

from langchain_community.llms.databricks import Databricks
from langchain_community.llms.databricks import (
Databricks,
_load_pickled_fn_from_hex_string,
)


class MockDatabricksServingEndpointClient:
Expand All @@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]:
return request


@pytest.mark.requires("cloudpickle")
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
import cloudpickle

monkeypatch.setattr(
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
MockDatabricksServingEndpointClient,
Expand All @@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
transform_input_fn=transform_input,
)
params = llm._default_params
pickled_string = pickle.dumps(transform_input).hex()
pickled_string = cloudpickle.dumps(transform_input).hex()
assert params["transform_input_fn"] == pickled_string

request = {"prompt": "What is the meaning of life?"}
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
assert fn(**request) == transform_input(**request)

0 comments on commit 847f24f

Please sign in to comment.