Skip to content

Commit

Permalink
add openvino rerank import test
Browse files Browse the repository at this point in the history
fix format

fix format

fix format

fix format

fix format

fix format

fix format
  • Loading branch information
OpenVINO-dev-contest committed Mar 30, 2024
1 parent e2b40e2 commit 809fa81
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 11 deletions.
20 changes: 12 additions & 8 deletions docs/docs/integrations/retrievers/openvino_rerank.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"source": [
"# Helper function for printing docs\n",
"\n",
"\n",
"def pretty_print_docs(docs):\n",
" print(\n",
" f\"\\n{'-' * 100}\\n\".join(\n",
Expand Down Expand Up @@ -365,24 +366,23 @@
}
],
"source": [
"from langchain.embeddings import OpenVINOEmbeddings\n",
"from langchain_community.document_loaders import TextLoader\n",
"from langchain_community.vectorstores import FAISS\n",
"from langchain.embeddings import OpenVINOEmbeddings\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"\n",
"documents = TextLoader(\n",
" \"../../modules/state_of_the_union.txt\",\n",
").load()\n",
"text_splitter = RecursiveCharacterTextSplitter(\n",
" chunk_size=500, chunk_overlap=100)\n",
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
"texts = text_splitter.split_documents(documents)\n",
"for idx, text in enumerate(texts):\n",
" text.metadata[\"id\"] = idx\n",
"\n",
"embedding = OpenVINOEmbeddings(\n",
" model_name_or_path=\"sentence-transformers/all-mpnet-base-v2\")\n",
"retriever = FAISS.from_documents(\n",
" texts, embedding).as_retriever(search_kwargs={\"k\": 20})\n",
" model_name_or_path=\"sentence-transformers/all-mpnet-base-v2\"\n",
")\n",
"retriever = FAISS.from_documents(texts, embedding).as_retriever(search_kwargs={\"k\": 20})\n",
"\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = retriever.get_relevant_documents(query)\n",
Expand Down Expand Up @@ -550,11 +550,15 @@
],
"source": [
"from pathlib import Path\n",
"\n",
"ov_model_dir = \"bge-reranker-large-ov\"\n",
"if not Path(ov_model_dir).exists():\n",
" from optimum.intel.openvino import OVModelForSequenceClassification\n",
" from transformers import AutoTokenizer\n",
" ov_model = OVModelForSequenceClassification.from_pretrained(model_name, compile=False, export=True)\n",
"\n",
" ov_model = OVModelForSequenceClassification.from_pretrained(\n",
" model_name, compile=False, export=True\n",
" )\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" ov_model.half()\n",
" ov_model.save_pretrained(ov_model_dir)\n",
Expand All @@ -575,7 +579,7 @@
}
],
"source": [
"compressor = OpenVINOReranker(model_name_or_path=ov_model_dir)\n"
"compressor = OpenVINOReranker(model_name_or_path=ov_model_dir)"
]
},
{
Expand Down
6 changes: 5 additions & 1 deletion docs/docs/integrations/text_embedding/openvino.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,15 @@
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"ov_model_dir = \"all-mpnet-base-v2-ov\"\n",
"if not Path(ov_model_dir).exists():\n",
" from optimum.intel.openvino import OVModelForFeatureExtraction\n",
" from transformers import AutoTokenizer\n",
" ov_model = OVModelForFeatureExtraction.from_pretrained(model_name, compile=False, export=True)\n",
"\n",
" ov_model = OVModelForFeatureExtraction.from_pretrained(\n",
" model_name, compile=False, export=True\n",
" )\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" ov_model.half()\n",
" ov_model.save_pretrained(ov_model_dir)\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def require_model_export(

self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)

def rerank(self, request: Any):
def rerank(self, request: Any) -> Any:
query = request.query
passages = request.passages

Expand Down
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/retrievers/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"MetalRetriever",
"MilvusRetriever",
"OutlineRetriever",
"OpenVINOReranker",
"PineconeHybridSearchRetriever",
"PubMedRetriever",
"QdrantSparseVectorRetriever",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

__all__ = [
"OpenVINOReranker",
]
]

0 comments on commit 809fa81

Please sign in to comment.