Skip to content

Commit

Permalink
openai[patch]: fix azure embedding length check (#19870)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis committed Apr 1, 2024
1 parent d62e84c commit be92cf5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
25 changes: 18 additions & 7 deletions libs/partners/openai/langchain_openai/embeddings/azure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Azure OpenAI embeddings wrapper."""

from __future__ import annotations

import os
Expand Down Expand Up @@ -57,6 +58,8 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
validate_base_url: bool = True
chunk_size: int = 2048
"""Maximum number of texts to embed in each batch"""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -102,7 +105,11 @@ def validate_environment(cls, values: Dict) -> Dict:
# Azure OpenAI embedding models allow a maximum of 2048 texts
# at a time in each batch
# See: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/embeddings?tabs=console#best-practices
values["chunk_size"] = min(values["chunk_size"], 2048)
if values["chunk_size"] > 2048:
raise ValueError(
"Azure OpenAI embeddings only allow a maximum of 2048 texts at a time "
"in each batch."
)
# For backwards compatibility. Before openai v1, no distinction was made
# between azure_endpoint and base_url (openai_api_base).
openai_api_base = values["openai_api_base"]
Expand All @@ -126,12 +133,16 @@ def validate_environment(cls, values: Dict) -> Dict:
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"azure_ad_token": values["azure_ad_token"].get_secret_value()
if values["azure_ad_token"]
else None,
"api_key": (
values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None
),
"azure_ad_token": (
values["azure_ad_token"].get_secret_value()
if values["azure_ad_token"]
else None
),
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def test_azure_openai_embedding_documents_chunk_size() -> None:
embedding = _get_embeddings()
embedding.embedding_ctx_length = 8191
output = embedding.embed_documents(documents)
# Max 16 chunks per batch on Azure OpenAI embeddings
assert embedding.chunk_size == 16
# Max 2048 chunks per batch on Azure OpenAI embeddings
assert embedding.chunk_size == 2048
assert len(output) == 20
assert all([len(out) == 1536 for out in output])

Expand Down

0 comments on commit be92cf5

Please sign in to comment.