Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Databricks SerDe uses cloudpickle instead of pickle #18607

Merged
merged 15 commits into from
Mar 6, 2024
15 changes: 12 additions & 3 deletions libs/community/langchain_community/llms/databricks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import pickle
import re
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -225,7 +224,12 @@ def _is_hex_string(data: str) -> bool:
def _load_pickled_fn_from_hex_string(data: str) -> Callable:
"""Loads a pickled function from a hexadecimal string."""
try:
return pickle.loads(bytes.fromhex(data))
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

try:
return cloudpickle.loads(bytes.fromhex(data))
except Exception as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to catch import errors separately so that you can provide a helpful error message (e.g. "please install cloudpickle with ...").

raise ValueError(
f"Failed to load the pickled function from a hexadecimal string. Error: {e}"
Expand All @@ -235,7 +239,12 @@ def _load_pickled_fn_from_hex_string(data: str) -> Callable:
def _pickle_fn_to_hex_string(fn: Callable) -> str:
"""Pickles a function and returns the hexadecimal string."""
try:
return pickle.dumps(fn).hex()
import cloudpickle
except Exception as e:
raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}")

try:
return cloudpickle.dumps(fn).hex()
except Exception as e:
raise ValueError(f"Failed to pickle the function: {e}")

Expand Down
8 changes: 4 additions & 4 deletions libs/community/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion libs/community/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ hologres-vector = {version = "^0.0.6", optional = true}
praw = {version = "^7.7.1", optional = true}
msal = {version = "^1.25.0", optional = true}
databricks-vectorsearch = {version = "^0.21", optional = true}
cloudpickle = {version = ">=2.0.0", optional = true}
dgml-utils = {version = "^0.3.0", optional = true}
datasets = {version = "^2.15.0", optional = true}
tree-sitter = {version = "^0.20.2", optional = true}
Expand Down Expand Up @@ -249,6 +250,7 @@ extended_testing = [
"hologres-vector",
"praw",
"databricks-vectorsearch",
"cloudpickle",
"dgml-utils",
"cohere",
"tree-sitter",
Expand All @@ -260,7 +262,8 @@ extended_testing = [
"elasticsearch",
"hdbcli",
"oci",
"rdflib"
"rdflib",
"cloudpickle",
]

[tool.ruff]
Expand Down
16 changes: 13 additions & 3 deletions libs/community/tests/unit_tests/llms/test_databricks.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""test Databricks LLM"""
import pickle
from typing import Any, Dict

import pytest
from pytest import MonkeyPatch

from langchain_community.llms.databricks import Databricks
from langchain_community.llms.databricks import (
Databricks,
_load_pickled_fn_from_hex_string,
)


class MockDatabricksServingEndpointClient:
Expand All @@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]:
return request


@pytest.mark.requires("cloudpickle")
def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
import cloudpickle

monkeypatch.setattr(
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
MockDatabricksServingEndpointClient,
Expand All @@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
transform_input_fn=transform_input,
)
params = llm._default_params
pickled_string = pickle.dumps(transform_input).hex()
pickled_string = cloudpickle.dumps(transform_input).hex()
assert params["transform_input_fn"] == pickled_string

request = {"prompt": "What is the meaning of life?"}
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
assert fn(**request) == transform_input(**request)