forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Add SparkLLM Text Embedding Model and SparkLLM intr…
…oduction (langchain-ai#17573)
- Loading branch information
Showing
6 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SparkLLM | ||
|
||
>[SparkLLM](https://xinghuo.xfyun.cn/spark) is a large-scale cognitive model independently developed by iFLYTEK. | ||
It has cross-domain knowledge and language understanding ability by learning a large amount of texts, codes and images. | ||
It can understand and perform tasks based on natural dialogue. | ||
|
||
## SparkLLM Chat Model | ||
An example is available at [example](/docs/integrations/chat/sparkllm). | ||
|
||
## SparkLLM Text Embedding Model | ||
An example is available at [example](/docs/integrations/text_embedding/sparkllm) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# SparkLLM Text Embeddings" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html\n", | ||
"\n", | ||
"An API key is required to use this embedding model. You can get one by registering at https://platform.SparkLLM-ai.com/docs/text-Embedding." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"SparkLLMTextEmbeddings support 2K token window and preduces vectors with 2560 dimensions." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.embeddings import SparkLLMTextEmbeddings\n", | ||
"\n", | ||
"embeddings = SparkLLMTextEmbeddings(\n", | ||
" spark_app_id=\"sk-*\", spark_api_key=\"\", spark_api_secret=\"\"\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Alternatively, you can set API key this way:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"os.environ[\"SPARK_APP_ID\"] = \"YOUR_APP_ID\"\n", | ||
"os.environ[\"SPARK_API_KEY\"] = \"YOUR_API_KEY\"\n", | ||
"os.environ[\"SPARK_API_SECRET\"] = \"YOUR_API_SECRET\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"text_1 = \"iFLYTEK is a well-known intelligent speech and artificial intelligence publicly listed company in the Asia-Pacific Region. Since its establishment, the company is devoted to cornerstone technological research in speech and languages, natural language understanding, machine learning, machine reasoning, adaptive learning, and has maintained the world-leading position in those domains. The company actively promotes the development of A.I. products and their sector-based applications, with visions of enabling machines to listen and speak, understand and think, creating a better world with artificial intelligence.\"\n", | ||
"text_2 = \"iFLYTEK Open Platform was launched in 2010 by iFLYTEK as China’s first Artificial Intelligence open platform for Mobile Internet and intelligent hardware developers.\"\n", | ||
"\n", | ||
"query_result = embeddings.embed_query(text_2)\n", | ||
"query_result" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"doc_result = embeddings.embed_documents([text_1, text_2])\n", | ||
"doc_result" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
184 changes: 184 additions & 0 deletions
184
libs/community/langchain_community/embeddings/sparkllm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
import base64 | ||
import hashlib | ||
import hmac | ||
import json | ||
import logging | ||
from datetime import datetime | ||
from time import mktime | ||
from typing import Any, Dict, List, Optional | ||
from urllib.parse import urlencode | ||
from wsgiref.handlers import format_date_time | ||
|
||
import numpy as np | ||
import requests | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator | ||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env | ||
from numpy import ndarray | ||
|
||
# Used for document and knowledge embedding | ||
EMBEDDING_P_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/sa8a05c27" | ||
# Used for user questions embedding | ||
EMBEDDING_Q_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/s50d55a16" | ||
|
||
# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/). | ||
|
||
# Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html | ||
# Developers need to create an application in the console first, use the appid, APIKey, | ||
# and APISecret provided in the application for authentication, | ||
# and generate an authentication URL for handshake. | ||
# You can get one by registering at https://console.xfyun.cn/services/bm3. | ||
# SparkLLMTextEmbeddings support 2K token window and preduces vectors with | ||
# 2560 dimensions. | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Url: | ||
def __init__(self, host: str, path: str, schema: str) -> None: | ||
self.host = host | ||
self.path = path | ||
self.schema = schema | ||
pass | ||
|
||
|
||
class SparkLLMTextEmbeddings(BaseModel, Embeddings): | ||
"""SparkLLM Text Embedding models.""" | ||
|
||
spark_app_id: SecretStr | ||
spark_api_key: SecretStr | ||
spark_api_secret: SecretStr | ||
|
||
@root_validator(allow_reuse=True) | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that auth token exists in environment.""" | ||
cls.spark_app_id = convert_to_secret_str( | ||
get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID") | ||
) | ||
cls.spark_api_key = convert_to_secret_str( | ||
get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY") | ||
) | ||
cls.spark_api_secret = convert_to_secret_str( | ||
get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET") | ||
) | ||
return values | ||
|
||
def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]: | ||
url = self._assemble_ws_auth_url( | ||
request_url=host, | ||
method="POST", | ||
api_key=self.spark_api_key.get_secret_value(), | ||
api_secret=self.spark_api_secret.get_secret_value(), | ||
) | ||
content = self._get_body(self.spark_app_id.get_secret_value(), texts) | ||
response = requests.post( | ||
url, json=content, headers={"content-type": "application/json"} | ||
).text | ||
res_arr = self._parser_message(response) | ||
if res_arr is not None: | ||
return res_arr.tolist() | ||
return None | ||
|
||
def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override] | ||
"""Public method to get embeddings for a list of documents. | ||
Args: | ||
texts: The list of texts to embed. | ||
Returns: | ||
A list of embeddings, one for each text, or None if an error occurs. | ||
""" | ||
return self._embed(texts, EMBEDDING_P_API_URL) | ||
|
||
def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override] | ||
"""Public method to get embedding for a single query text. | ||
Args: | ||
text: The text to embed. | ||
Returns: | ||
Embeddings for the text, or None if an error occurs. | ||
""" | ||
result = self._embed([text], EMBEDDING_Q_API_URL) | ||
return result[0] if result is not None else None | ||
|
||
@staticmethod | ||
def _assemble_ws_auth_url( | ||
request_url: str, method: str = "GET", api_key: str = "", api_secret: str = "" | ||
) -> str: | ||
u = SparkLLMTextEmbeddings._parse_url(request_url) | ||
host = u.host | ||
path = u.path | ||
now = datetime.now() | ||
date = format_date_time(mktime(now.timetuple())) | ||
signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format( | ||
host, date, method, path | ||
) | ||
signature_sha = hmac.new( | ||
api_secret.encode("utf-8"), | ||
signature_origin.encode("utf-8"), | ||
digestmod=hashlib.sha256, | ||
).digest() | ||
signature_sha_str = base64.b64encode(signature_sha).decode(encoding="utf-8") | ||
authorization_origin = ( | ||
'api_key="%s", algorithm="%s", headers="%s", signature="%s"' | ||
% (api_key, "hmac-sha256", "host date request-line", signature_sha_str) | ||
) | ||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( | ||
encoding="utf-8" | ||
) | ||
values = {"host": host, "date": date, "authorization": authorization} | ||
|
||
return request_url + "?" + urlencode(values) | ||
|
||
@staticmethod | ||
def _parse_url(request_url: str) -> Url: | ||
stidx = request_url.index("://") | ||
host = request_url[stidx + 3 :] | ||
schema = request_url[: stidx + 3] | ||
edidx = host.index("/") | ||
if edidx <= 0: | ||
raise AssembleHeaderException("invalid request url:" + request_url) | ||
path = host[edidx:] | ||
host = host[:edidx] | ||
u = Url(host, path, schema) | ||
return u | ||
|
||
@staticmethod | ||
def _get_body(appid: str, text: List[str]) -> Dict[str, Any]: | ||
body = { | ||
"header": {"app_id": appid, "uid": "39769795890", "status": 3}, | ||
"parameter": {"emb": {"feature": {"encoding": "utf8"}}}, | ||
"payload": { | ||
"messages": { | ||
"text": base64.b64encode(json.dumps(text).encode("utf-8")).decode() | ||
} | ||
}, | ||
} | ||
return body | ||
|
||
@staticmethod | ||
def _parser_message( | ||
message: str, | ||
) -> Optional[ndarray]: | ||
data = json.loads(message) | ||
code = data["header"]["code"] | ||
if code != 0: | ||
logger.warning(f"Request error: {code}, {data}") | ||
return None | ||
else: | ||
text_base = data["payload"]["feature"]["text"] | ||
text_data = base64.b64decode(text_base) | ||
dt = np.dtype(np.float32) | ||
dt = dt.newbyteorder("<") | ||
text = np.frombuffer(text_data, dtype=dt) | ||
if len(text) > 2560: | ||
array = text[:2560] | ||
else: | ||
array = text | ||
return array | ||
|
||
|
||
class AssembleHeaderException(Exception): | ||
def __init__(self, msg: str) -> None: | ||
self.message = msg |
35 changes: 35 additions & 0 deletions
35
libs/community/tests/integration_tests/embeddings/test_sparkllm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Test SparkLLM Text Embedding.""" | ||
from langchain_community.embeddings.sparkllm import SparkLLMTextEmbeddings | ||
|
||
|
||
def test_baichuan_embedding_documents() -> None: | ||
"""Test SparkLLM Text Embedding for documents.""" | ||
documents = [ | ||
"iFLYTEK is a well-known intelligent speech and artificial intelligence " | ||
"publicly listed company in the Asia-Pacific Region. Since its establishment," | ||
"the company is devoted to cornerstone technological research " | ||
"in speech and languages, natural language understanding, machine learning," | ||
"machine reasoning, adaptive learning, " | ||
"and has maintained the world-leading position in those " | ||
"domains. The company actively promotes the development of A.I. " | ||
"products and their sector-based " | ||
"applications, with visions of enabling machines to listen and speak, " | ||
"understand and think, " | ||
"creating a better world with artificial intelligence." | ||
] | ||
embedding = SparkLLMTextEmbeddings() | ||
output = embedding.embed_documents(documents) | ||
assert len(output) == 1 # type: ignore[arg-type] | ||
assert len(output[0]) == 2560 # type: ignore[index] | ||
|
||
|
||
def test_baichuan_embedding_query() -> None: | ||
"""Test SparkLLM Text Embedding for query.""" | ||
document = ( | ||
"iFLYTEK Open Platform was launched in 2010 by iFLYTEK as China’s " | ||
"first Artificial Intelligence open platform for Mobile Internet " | ||
"and intelligent hardware developers" | ||
) | ||
embedding = SparkLLMTextEmbeddings() | ||
output = embedding.embed_query(document) | ||
assert len(output) == 2560 # type: ignore[arg-type] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,6 +60,7 @@ | |
"OCIGenAIEmbeddings", | ||
"QuantizedBiEncoderEmbeddings", | ||
"NeMoEmbeddings", | ||
"SparkLLMTextEmbeddings", | ||
] | ||
|
||
|
||
|