community: Add SparkLLM to community #17702

Merged 10 commits on Feb 20, 2024
141 changes: 141 additions & 0 deletions docs/docs/integrations/llms/sparkllm.ipynb
@@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SparkLLM\n",
"[SparkLLM](https://xinghuo.xfyun.cn/spark) is a large-scale cognitive model independently developed by iFLYTEK.\n",
"It has cross-domain knowledge and language understanding ability by learning a large amount of texts, codes and images.\n",
"It can understand and perform tasks based on natural dialogue."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prerequisite\n",
"- Get SparkLLM's app_id, api_key and api_secret from [iFlyTek SparkLLM API Console](https://console.xfyun.cn/services/bm3) (for more info, see [iFlyTek SparkLLM Intro](https://xinghuo.xfyun.cn/sparkapi) ), then set environment variables `IFLYTEK_SPARK_APP_ID`, `IFLYTEK_SPARK_API_KEY` and `IFLYTEK_SPARK_API_SECRET` or pass parameters when creating `ChatSparkLLM` as the demo above."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Use SparkLLM"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"IFLYTEK_SPARK_APP_ID\"] = \"app_id\"\n",
"os.environ[\"IFLYTEK_SPARK_API_KEY\"] = \"api_key\"\n",
"os.environ[\"IFLYTEK_SPARK_API_SECRET\"] = \"api_secret\""
]
},
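{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, the credentials can be passed directly when constructing `SparkLLM`. The parameter names below (`spark_app_id`, `spark_api_key`, `spark_api_secret`) are assumed from the environment-variable names above; this is a minimal sketch, not the canonical setup:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import SparkLLM\n",
"\n",
"# Assumed constructor parameter names; if they differ, fall back to the\n",
"# IFLYTEK_SPARK_* environment variables set above.\n",
"llm = SparkLLM(\n",
"    spark_app_id=\"app_id\",\n",
"    spark_api_key=\"api_key\",\n",
"    spark_api_secret=\"api_secret\",\n",
")"
]
},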
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"My name is iFLYTEK Spark. How can I assist you today?\n"
]
}
],
"source": [
"from langchain_community.llms import SparkLLM\n",
"\n",
"# Load the model\n",
"llm = SparkLLM()\n",
"\n",
"res = llm(\"What's your name?\")\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-18T13:04:29.305856Z",
"start_time": "2024-02-18T13:04:28.085715Z"
}
},
"outputs": [
{
"data": {
"text/plain": "LLMResult(generations=[[Generation(text='Hello! How can I assist you today?')]], llm_output=None, run=[RunInfo(run_id=UUID('d8cdcd41-a698-4cbf-a28d-e74f9cd2037b'))])"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"res = llm.generate(prompts=[\"hello!\"])\n",
"res"
]
},
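{
"cell_type": "markdown",
"metadata": {},
"source": [
"Beyond `generate`, the standard `Runnable` interface also provides `batch`, which returns plain strings instead of an `LLMResult`. A minimal sketch (outputs omitted):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# `batch` returns one string completion per input prompt.\n",
"results = llm.batch([\"What's your name?\", \"hello!\"])\n",
"for r in results:\n",
"    print(r)"
]
},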
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2024-02-18T13:05:44.640035Z",
"start_time": "2024-02-18T13:05:43.244126Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello! How can I assist you today?\n"
]
}
],
"source": [
"for res in llm.stream(\"foo:\"):\n",
" print(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
10 changes: 10 additions & 0 deletions libs/community/langchain_community/llms/__init__.py
@@ -582,6 +582,12 @@ def _import_volcengine_maas() -> Any:
return VolcEngineMaasLLM


def _import_sparkllm() -> Any:
from langchain_community.llms.sparkllm import SparkLLM

return SparkLLM


def __getattr__(name: str) -> Any:
if name == "AI21":
return _import_ai21()
@@ -769,6 +775,8 @@ def __getattr__(name: str) -> Any:
k: v() for k, v in get_type_to_cls_dict().items()
}
return type_to_cls_dict
elif name == "SparkLLM":
return _import_sparkllm()
else:
raise AttributeError(f"Could not find: {name}")

@@ -861,6 +869,7 @@ def __getattr__(name: str) -> Any:
"YandexGPT",
"Yuan2",
"VolcEngineMaasLLM",
"SparkLLM",
]


@@ -950,4 +959,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"yandex_gpt": _import_yandex_gpt,
"yuan2": _import_yuan2,
"VolcEngineMaasLLM": _import_volcengine_maas,
"SparkLLM": _import_sparkllm(),
}
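Note on the registration pattern above: `SparkLLM` is exposed lazily. The class is imported only when the attribute is first accessed through the module-level `__getattr__`, and `get_type_to_cls_dict()` maps each key to the import callable itself rather than to its result, as the declared return type `Dict[str, Callable[[], Type[BaseLLM]]]` requires. A condensed, self-contained sketch of the same pattern, simplified from the full module:

from typing import Any, Callable, Dict


def _import_sparkllm() -> Any:
    # Deferred import: langchain_community.llms.sparkllm is loaded only on demand.
    from langchain_community.llms.sparkllm import SparkLLM

    return SparkLLM


def __getattr__(name: str) -> Any:
    # PEP 562 module-level __getattr__: runs only for names not already defined,
    # e.g. `from langchain_community.llms import SparkLLM`.
    if name == "SparkLLM":
        return _import_sparkllm()
    raise AttributeError(f"Could not find: {name}")


def get_type_to_cls_dict() -> Dict[str, Callable[[], Any]]:
    # Map each key to the import callable itself (no call), so nothing is
    # imported until a consumer invokes the mapped function.
    return {"SparkLLM": _import_sparkllm}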