Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9a95db9

Browse files
authored Feb 6, 2025
feat(embeddings): use stdlib array type for improved performance (#2060)
1 parent e0ca9f0 commit 9a95db9

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed
 

‎src/openai/resources/embeddings.py

+17 -12
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import array
56
import base64
67
from typing import List, Union, Iterable, cast
78
from typing_extensions import Literal
@@ -102,7 +103,7 @@ def create(
102103
"dimensions": dimensions,
103104
"encoding_format": encoding_format,
104105
}
105-
if not is_given(encoding_format) and has_numpy():
106+
if not is_given(encoding_format):
106107
params["encoding_format"] = "base64"
107108

108109
def parser(obj: CreateEmbeddingResponse) -> CreateEmbeddingResponse:
@@ -113,12 +114,14 @@ def parser(obj: CreateEmbeddingResponse) -> CreateEmbeddingResponse:
113114
for embedding in obj.data:
114115
data = cast(object, embedding.embedding)
115116
if not isinstance(data, str):
116-
# numpy is not installed / base64 optimisation isn't enabled for this model yet
117117
continue
118-
119-
embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call]
120-
base64.b64decode(data), dtype="float32"
121-
).tolist()
118+
if not has_numpy():
119+
# use array for base64 optimisation
120+
embedding.embedding = array.array("f", base64.b64decode(data)).tolist()
121+
else:
122+
embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call]
123+
base64.b64decode(data), dtype="float32"
124+
).tolist()
122125

123126
return obj
124127

@@ -215,7 +218,7 @@ async def create(
215218
"dimensions": dimensions,
216219
"encoding_format": encoding_format,
217220
}
218-
if not is_given(encoding_format) and has_numpy():
221+
if not is_given(encoding_format):
219222
params["encoding_format"] = "base64"
220223

221224
def parser(obj: CreateEmbeddingResponse) -> CreateEmbeddingResponse:
@@ -226,12 +229,14 @@ def parser(obj: CreateEmbeddingResponse) -> CreateEmbeddingResponse:
226229
for embedding in obj.data:
227230
data = cast(object, embedding.embedding)
228231
if not isinstance(data, str):
229-
# numpy is not installed / base64 optimisation isn't enabled for this model yet
230232
continue
231-
232-
embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call]
233-
base64.b64decode(data), dtype="float32"
234-
).tolist()
233+
if not has_numpy():
234+
# use array for base64 optimisation
235+
embedding.embedding = array.array("f", base64.b64decode(data)).tolist()
236+
else:
237+
embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call]
238+
base64.b64decode(data), dtype="float32"
239+
).tolist()
235240

236241
return obj
237242

0 commit comments

Comments
 (0)
Please sign in to comment.