Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
fengjialin authored and fengjialin committed Mar 7, 2024
1 parent 18702ea commit 2b4eeec
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _import_baiduvectordb() -> Any:

return BaiduVectorDB


def _import_baiducloud_vector_search() -> Any:
from langchain_community.vectorstores.baiducloud_vector_search import BESVectorStore

Expand Down
84 changes: 49 additions & 35 deletions libs/community/langchain_community/vectorstores/baiduvectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ class ConnectionParams:
connection_timeout_in_mills (int) : Request Timeout.
"""

def __init__(self, endpoint: str, api_key: str, account: str = "root",
connection_timeout_in_mills: int = 50*1000):
def __init__(
self,
endpoint: str,
api_key: str,
account: str = "root",
connection_timeout_in_mills: int = 50 * 1000,
):
self.endpoint = endpoint
self.api_key = api_key
self.account = account
Expand Down Expand Up @@ -96,8 +101,9 @@ def __init__(
self.embedding_func = embedding
self.table_params = table_params
config = configuration.Configuration(
credentials=auth.BceCredentials(connection_params.account,
connection_params.api_key),
credentials=auth.BceCredentials(
connection_params.account, connection_params.api_key
),
endpoint=connection_params.endpoint,
connection_timeout_in_mills=connection_params.connection_timeout_in_mills,
)
Expand Down Expand Up @@ -144,40 +150,43 @@ def _create_table(self, table_name: str) -> None:
fields = []
fields.append(
schema.Field(
self.field_id,
self.mochowenum.FieldType.STRING,
self.field_id,
self.mochowenum.FieldType.STRING,
primary_key=True,
partition_key=True,
partition_key=True,
auto_increment=False,
not_null=True
not_null=True,
)
)
fields.append(
schema.Field(
self.field_vector,
self.field_vector,
self.mochowenum.FieldType.FLOAT_VECTOR,
dimension=self.table_params.dimension
dimension=self.table_params.dimension,
)
)
fields.append(schema.Field(self.field_text, self.mochowenum.FieldType.STRING))
fields.append(
schema.Field(
self.field_metadata,
self.mochowenum.FieldType.STRING
)
schema.Field(self.field_metadata, self.mochowenum.FieldType.STRING)
)
indexes = []
indexes.append(schema.VectorIndex(index_name=self.index_vector,
index_type=index_type,
field=self.field_vector,
metric_type=metric_type,
params=params))
indexes.append(
schema.VectorIndex(
index_name=self.index_vector,
index_type=index_type,
field=self.field_vector,
metric_type=metric_type,
params=params,
)
)

self.table = self.database.create_table(
table_name=table_name,
replication=self.table_params.replication,
partition=self.mochowtable.Partition(partition_num=self.table_params.partition),
schema=schema.Schema(fields=fields, indexes=indexes)
partition=self.mochowtable.Partition(
partition_num=self.table_params.partition
),
schema=schema.Schema(fields=fields, indexes=indexes),
)
# wait table ready
time.sleep(20)
Expand Down Expand Up @@ -323,10 +332,12 @@ def _similarity_search_with_score(
"""Perform a search on a query string and return results with score."""
ef = 10 if param is None else param.get("ef", 10)

anns = self.mochowtable.AnnSearch(vector_field=self.field_vector,
vector_floats=[float(num) for num in embedding],
params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
filter=expr)
anns = self.mochowtable.AnnSearch(
vector_field=self.field_vector,
vector_floats=[float(num) for num in embedding],
params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
filter=expr,
)
res = self.table.search(anns=anns)

rows = [[item] for item in res.rows]
Expand All @@ -336,12 +347,13 @@ def _similarity_search_with_score(
return ret
for row in rows:
for result in row:
row_data = result.get('row', {})
row_data = result.get("row", {})
meta = row_data.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
doc = Document(page_content=row_data.get(self.field_text),
metadata=meta)
doc = Document(
page_content=row_data.get(self.field_text), metadata=meta
)
pair = (doc, result.get("distance", 0.0))
ret.append(pair)
return ret
Expand Down Expand Up @@ -380,10 +392,12 @@ def _max_marginal_relevance_search(
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR."""
ef = 10 if param is None else param.get("ef", 10)
anns = self.mochowtable.AnnSearch(vector_field=self.field_vector,
vector_floats=[float(num) for num in embedding],
params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
filter=expr)
anns = self.mochowtable.AnnSearch(
vector_field=self.field_vector,
vector_floats=[float(num) for num in embedding],
params=self.mochowtable.HNSWSearchParams(ef=ef, limit=k),
filter=expr,
)
res = self.table.search(anns=anns, retrieve_vector=True)

# Organize results.
Expand All @@ -394,12 +408,12 @@ def _max_marginal_relevance_search(
return documents
for row in rows:
for result in row:
row_data = result.get('row', {})
row_data = result.get("row", {})
meta = row_data.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
doc = Document(page_content=row_data.get(self.field_text),
metadata=meta)
doc = Document(
page_content=row_data.get(self.field_text), metadata=meta)
documents.append(doc)
ordered_result_embeddings.append(row_data.get(self.field_vector))
# Get the new order of results.
Expand Down

0 comments on commit 2b4eeec

Please sign in to comment.