Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: support Milvus more params #15447

Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
50 changes: 43 additions & 7 deletions libs/community/langchain_community/vectorstores/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def __init__(
text_field: str = "text",
vector_field: str = "vector",
metadata_field: Optional[str] = None,
partition_names: Optional[list] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
):
"""Initialize the Milvus vector store."""
try:
Expand Down Expand Up @@ -158,6 +161,10 @@ def __init__(
self._vector_field = vector_field
self._metadata_field = metadata_field
self.fields: list[str] = []
self.partition_names = partition_names
self.replica_number = replica_number
self.timeout = timeout

# Create the connection to the server
if connection_args is None:
connection_args = DEFAULT_MILVUS_CONNECTION
Expand All @@ -176,7 +183,11 @@ def __init__(
self.col = None

# Initialize the vector store
self._init()
self._init(
partition_names=partition_names,
replica_number=replica_number,
timeout=timeout,
)

@property
def embeddings(self) -> Embeddings:
Expand Down Expand Up @@ -235,14 +246,23 @@ def _create_connection_alias(self, connection_args: dict) -> str:
raise e

def _init(
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
self,
embeddings: Optional[list] = None,
metadatas: Optional[list[dict]] = None,
partition_names: Optional[list] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
) -> None:
if embeddings is not None:
self._create_collection(embeddings, metadatas)
self._extract_fields()
self._create_index()
self._create_search_params()
self._load()
self._load(
partition_names=partition_names,
replica_number=replica_number,
timeout=timeout,
)

def _create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None
Expand Down Expand Up @@ -396,12 +416,21 @@ def _create_search_params(self) -> None:
self.search_params = self.default_search_params[index_type]
self.search_params["metric_type"] = metric_type

def _load(self) -> None:
def _load(
self,
partition_names: Optional[list] = None,
replica_number: int = 1,
timeout: Optional[float] = None,
) -> None:
"""Load the collection if available."""
from pymilvus import Collection

if isinstance(self.col, Collection) and self._get_index() is not None:
self.col.load()
self.col.load(
partition_names=partition_names,
replica_number=replica_number,
timeout=timeout,
)

def add_texts(
self,
Expand All @@ -417,7 +446,7 @@ def add_texts(
in creating a new Collection. The data of the first entity decides
the schema of the new collection, the dim is extracted from the first
embedding and the columns are decided by the first metadata dict.
Metada keys will need to be present for all inserted values. At
Metadata keys will need to be present for all inserted values. At
the moment there is no None equivalent in Milvus.

Args:
Expand Down Expand Up @@ -451,7 +480,14 @@ def add_texts(

# If the collection hasn't been initialized yet, perform all steps to do so
if not isinstance(self.col, Collection):
self._init(embeddings, metadatas)
kwargs = {"embeddings": embeddings, "metadatas": metadatas}
if self.partition_names:
kwargs["partition_names"] = self.partition_names
if self.replica_number:
kwargs["replica_number"] = self.replica_number
if self.timeout:
kwargs["timeout"] = self.timeout
self._init(**kwargs)

# Dict to hold all insert columns
insert_dict: dict[str, list] = {
Expand Down