community: Add SparkLLM Text Embedding Model and SparkLLM introduction #17573

Merged
merged 15 commits into from
Feb 20, 2024
11 changes: 11 additions & 0 deletions docs/docs/integrations/providers/sparkllm.mdx
@@ -0,0 +1,11 @@
# SparkLLM

>[SparkLLM](https://xinghuo.xfyun.cn/spark) is a large-scale cognitive model independently developed by iFLYTEK.
It acquires cross-domain knowledge and language understanding by learning from large amounts of text, code, and images.
It can understand and perform tasks based on natural dialogue.

## SparkLLM Chat Model
An example is available at [example](/docs/integrations/chat/sparkllm).

## SparkLLM Text Embedding Model
An example is available at [example](/docs/integrations/text_embedding/sparkllm).
90 changes: 90 additions & 0 deletions docs/docs/integrations/text_embedding/sparkllm.ipynb
@@ -0,0 +1,90 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SparkLLM Text Embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html\n",
"\n",
"An app ID, API key, and API secret are required to use this embedding model. You can get them by creating an application at https://console.xfyun.cn/services/bm3."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SparkLLMTextEmbeddings supports a 2K token window and produces vectors with 2560 dimensions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import SparkLLMTextEmbeddings\n",
"\n",
"embeddings = SparkLLMTextEmbeddings(\n",
"    spark_app_id=\"YOUR_APP_ID\",\n",
"    spark_api_key=\"YOUR_API_KEY\",\n",
"    spark_api_secret=\"YOUR_API_SECRET\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, you can set the credentials through environment variables:"
]
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"SPARK_APP_ID\"] = \"YOUR_APP_ID\"\n",
"os.environ[\"SPARK_API_KEY\"] = \"YOUR_API_KEY\"\n",
"os.environ[\"SPARK_API_SECRET\"] = \"YOUR_API_SECRET\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_1 = \"iFLYTEK is a well-known intelligent speech and artificial intelligence publicly listed company in the Asia-Pacific Region. Since its establishment, the company is devoted to cornerstone technological research in speech and languages, natural language understanding, machine learning, machine reasoning, adaptive learning, and has maintained the world-leading position in those domains. The company actively promotes the development of A.I. products and their sector-based applications, with visions of enabling machines to listen and speak, understand and think, creating a better world with artificial intelligence.\"\n",
"text_2 = \"iFLYTEK Open Platform was launched in 2010 by iFLYTEK as China’s first Artificial Intelligence open platform for Mobile Internet and intelligent hardware developers.\"\n",
"\n",
"query_result = embeddings.embed_query(text_2)\n",
"query_result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"doc_result = embeddings.embed_documents([text_1, text_2])\n",
"doc_result"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
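The notebook stops at printing `query_result` and `doc_result`. As a hedged follow-up, the sketch below shows one common way to compare a query embedding against document embeddings with cosine similarity. The helper names `cosine_similarity` and `rank_documents` are hypothetical and not part of this PR; the `None` handling is an assumption based on the `Optional` return types of `embed_query` and `embed_documents`.

```python
import math
from typing import List, Optional, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_documents(
    query_vec: Optional[List[float]],
    doc_vecs: Optional[List[List[float]]],
) -> List[int]:
    """Return document indices ordered from most to least similar to the query.

    Either argument may be None (assumed to mean the embedding request failed),
    in which case an empty ranking is returned.
    """
    if query_vec is None or doc_vecs is None:
        return []
    scores = [cosine_similarity(query_vec, d) for d in doc_vecs]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
```

With live credentials, `rank_documents(query_result, doc_result)` would order the two sample texts by similarity to the query; the helpers themselves need no network access.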
2 changes: 2 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
@@ -84,6 +84,7 @@
SentenceTransformerEmbeddings,
)
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.embeddings.sparkllm import SparkLLMTextEmbeddings
from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings
from langchain_community.embeddings.vertexai import VertexAIEmbeddings
from langchain_community.embeddings.volcengine import VolcanoEmbeddings
@@ -152,6 +153,7 @@
"OCIGenAIEmbeddings",
"QuantizedBiEncoderEmbeddings",
"NeMoEmbeddings",
"SparkLLMTextEmbeddings",
]


184 changes: 184 additions & 0 deletions libs/community/langchain_community/embeddings/sparkllm.py
@@ -0,0 +1,184 @@
import base64
import hashlib
import hmac
import json
import logging
from datetime import datetime
from time import mktime
from typing import Any, Dict, List, Optional
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time

import numpy as np
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from numpy import ndarray

# Used for document and knowledge embedding
EMBEDDING_P_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/sa8a05c27"
# Used for user questions embedding
EMBEDDING_Q_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/s50d55a16"

# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.
# (https://iflytek.com/en/).

# Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html
# Developers need to create an application in the console first, use the appid, APIKey,
# and APISecret provided in the application for authentication,
# and generate an authentication URL for handshake.
# You can create an application at https://console.xfyun.cn/services/bm3.
# SparkLLMTextEmbeddings supports a 2K token window and produces vectors with
# 2560 dimensions.

logger = logging.getLogger(__name__)


class Url:
    """Components of a request URL: host, path, and scheme prefix (e.g. https://)."""

    def __init__(self, host: str, path: str, schema: str) -> None:
        self.host = host
        self.path = path
        self.schema = schema


class SparkLLMTextEmbeddings(BaseModel, Embeddings):
    """SparkLLM Text Embedding models."""

    spark_app_id: SecretStr
    spark_api_key: SecretStr
    spark_api_secret: SecretStr

    @root_validator(pre=True, allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the credentials exist as arguments or in the environment."""
        # Write the resolved secrets back into `values` so they populate the
        # instance fields. Running with pre=True lets the environment fallback
        # apply before the required-field validation.
        values["spark_app_id"] = convert_to_secret_str(
            get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
        )
        values["spark_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
        )
        values["spark_api_secret"] = convert_to_secret_str(
            get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
        )
        return values

    def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
        url = self._assemble_ws_auth_url(
            request_url=host,
            method="POST",
            api_key=self.spark_api_key.get_secret_value(),
            api_secret=self.spark_api_secret.get_secret_value(),
        )
        content = self._get_body(self.spark_app_id.get_secret_value(), texts)
        response = requests.post(
            url, json=content, headers={"content-type": "application/json"}
        ).text
        res_arr = self._parser_message(response)
        if res_arr is not None:
            return res_arr.tolist()
        return None

    def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:  # type: ignore[override]
        """Public method to get embeddings for a list of documents.

        Args:
            texts: The list of texts to embed.

        Returns:
            A list of embeddings, one for each text, or None if an error occurs.
        """
        return self._embed(texts, EMBEDDING_P_API_URL)

    def embed_query(self, text: str) -> Optional[List[float]]:  # type: ignore[override]
        """Public method to get embedding for a single query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text, or None if an error occurs.
        """
        result = self._embed([text], EMBEDDING_Q_API_URL)
        return result[0] if result is not None else None

    @staticmethod
    def _assemble_ws_auth_url(
        request_url: str, method: str = "GET", api_key: str = "", api_secret: str = ""
    ) -> str:
        u = SparkLLMTextEmbeddings._parse_url(request_url)
        host = u.host
        path = u.path
        # HTTP-date (GMT) of the current time; it is part of the signed string.
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        # String to sign: host line, date line, and request line.
        signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
            host, date, method, path
        )
        signature_sha = hmac.new(
            api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha_str = base64.b64encode(signature_sha).decode(encoding="utf-8")
        authorization_origin = (
            'api_key="%s", algorithm="%s", headers="%s", signature="%s"'
            % (api_key, "hmac-sha256", "host date request-line", signature_sha_str)
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
            encoding="utf-8"
        )
        # The signed values travel as query parameters on the request URL.
        values = {"host": host, "date": date, "authorization": authorization}

        return request_url + "?" + urlencode(values)

    @staticmethod
    def _parse_url(request_url: str) -> Url:
        stidx = request_url.index("://")
        host = request_url[stidx + 3 :]
        schema = request_url[: stidx + 3]
        # find() returns -1 when the URL has no path, so the check below can fire;
        # index() would raise ValueError before reaching it.
        edidx = host.find("/")
        if edidx <= 0:
            raise AssembleHeaderException("invalid request url:" + request_url)
        path = host[edidx:]
        host = host[:edidx]
        u = Url(host, path, schema)
        return u

    @staticmethod
    def _get_body(appid: str, text: List[str]) -> Dict[str, Any]:
        body = {
            "header": {"app_id": appid, "uid": "39769795890", "status": 3},
            "parameter": {"emb": {"feature": {"encoding": "utf8"}}},
            "payload": {
                "messages": {
                    "text": base64.b64encode(json.dumps(text).encode("utf-8")).decode()
                }
            },
        }
        return body

    @staticmethod
    def _parser_message(
        message: str,
    ) -> Optional[ndarray]:
        data = json.loads(message)
        code = data["header"]["code"]
        if code != 0:
            logger.warning(f"Request error: {code}, {data}")
            return None
        text_base = data["payload"]["feature"]["text"]
        text_data = base64.b64decode(text_base)
        # The payload is a buffer of little-endian float32 values.
        dt = np.dtype(np.float32)
        dt = dt.newbyteorder("<")
        text = np.frombuffer(text_data, dtype=dt)
        # Keep at most the first 2560 values; slicing a shorter array is a no-op.
        return text[:2560]


class AssembleHeaderException(Exception):
    """Raised when a request URL cannot be split into host and path."""

    def __init__(self, msg: str) -> None:
        super().__init__(msg)
        self.message = msg
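The URL-signing step in `_assemble_ws_auth_url` can be exercised without touching the network. The following is a minimal standalone sketch of the same HMAC-SHA256 scheme, not the PR's code: `build_signed_url` and `decode_authorization` are hypothetical names, the `example.invalid` URL used in the note below is a placeholder, and the `now` parameter is an assumption added so that the output is reproducible.

```python
import base64
import hashlib
import hmac
from datetime import datetime, timezone
from email.utils import format_datetime
from typing import Optional
from urllib.parse import parse_qs, urlencode, urlparse


def build_signed_url(
    request_url: str,
    api_key: str,
    api_secret: str,
    method: str = "POST",
    now: Optional[datetime] = None,  # assumption: injectable clock for testing
) -> str:
    """Return request_url with host, date and authorization query parameters."""
    parsed = urlparse(request_url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("invalid request url: " + request_url)
    # RFC 1123 HTTP-date in GMT; the datetime must be timezone-aware UTC.
    date = format_datetime(now or datetime.now(timezone.utc), usegmt=True)
    # String to sign: host line, date line, request line.
    signature_origin = (
        f"host: {parsed.netloc}\ndate: {date}\n{method} {parsed.path} HTTP/1.1"
    )
    digest = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature = base64.b64encode(digest).decode("utf-8")
    authorization_origin = (
        f'api_key="{api_key}", algorithm="hmac-sha256", '
        f'headers="host date request-line", signature="{signature}"'
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
        "utf-8"
    )
    query = {"host": parsed.netloc, "date": date, "authorization": authorization}
    return request_url + "?" + urlencode(query)


def decode_authorization(signed_url: str) -> str:
    """Recover the plain-text authorization string from a signed URL."""
    query = parse_qs(urlparse(signed_url).query)
    return base64.b64decode(query["authorization"][0]).decode("utf-8")
```

Calling `build_signed_url("https://example.invalid/v1/private/demo", "key", "secret")` returns the original URL followed by the three signed query parameters, and `decode_authorization` on the result shows the `api_key`, `algorithm`, `headers`, and `signature` fields. Passing a fixed `now` makes the signature deterministic, which is convenient for unit tests.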
35 changes: 35 additions & 0 deletions libs/community/tests/integration_tests/embeddings/test_sparkllm.py
@@ -0,0 +1,35 @@
"""Test SparkLLM Text Embedding."""
from langchain_community.embeddings.sparkllm import SparkLLMTextEmbeddings


def test_sparkllm_embedding_documents() -> None:
    """Test SparkLLM Text Embedding for documents."""
    documents = [
        "iFLYTEK is a well-known intelligent speech and artificial intelligence "
        "publicly listed company in the Asia-Pacific Region. Since its establishment, "
        "the company is devoted to cornerstone technological research "
        "in speech and languages, natural language understanding, machine learning, "
        "machine reasoning, adaptive learning, "
        "and has maintained the world-leading position in those "
        "domains. The company actively promotes the development of A.I. "
        "products and their sector-based "
        "applications, with visions of enabling machines to listen and speak, "
        "understand and think, "
        "creating a better world with artificial intelligence."
    ]
    embedding = SparkLLMTextEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1  # type: ignore[arg-type]
    assert len(output[0]) == 2560  # type: ignore[index]


def test_sparkllm_embedding_query() -> None:
    """Test SparkLLM Text Embedding for query."""
    document = (
        "iFLYTEK Open Platform was launched in 2010 by iFLYTEK as China’s "
        "first Artificial Intelligence open platform for Mobile Internet "
        "and intelligent hardware developers"
    )
    embedding = SparkLLMTextEmbeddings()
    output = embedding.embed_query(document)
    assert len(output) == 2560  # type: ignore[arg-type]
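These integration tests need live credentials. The response-decoding step, by contrast, could be checked offline. The sketch below mirrors what `_parser_message` does (base64 payload of little-endian float32 values, truncated to 2560 entries) using only the standard library's `struct` instead of NumPy. The names `decode_embedding` and `fake_response` are hypothetical, and the response layout is assumed from the keys that `_parser_message` reads, not from the service's documentation.

```python
import base64
import json
import struct
from typing import List, Optional, Sequence


def decode_embedding(message: str, dims: int = 2560) -> Optional[List[float]]:
    """Decode a response body into a list of floats, or None on an error code."""
    data = json.loads(message)
    if data["header"]["code"] != 0:
        return None
    raw = base64.b64decode(data["payload"]["feature"]["text"])
    count = len(raw) // 4  # four bytes per float32 value
    values = struct.unpack(f"<{count}f", raw[: count * 4])
    return list(values[:dims])


def fake_response(vector: Sequence[float], code: int = 0) -> str:
    """Build a response body in the assumed layout, for offline tests only."""
    raw = struct.pack(f"<{len(vector)}f", *vector)
    return json.dumps(
        {
            "header": {"code": code},
            "payload": {"feature": {"text": base64.b64encode(raw).decode("ascii")}},
        }
    )
```

A round trip such as `decode_embedding(fake_response([0.5, -1.25, 2.0]))` exercises the happy path, and a non-zero `code` exercises the error path, without any request to the service.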
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
@@ -60,6 +60,7 @@
"OCIGenAIEmbeddings",
"QuantizedBiEncoderEmbeddings",
"NeMoEmbeddings",
"SparkLLMTextEmbeddings",
]

