forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
community[minor]: Add tidb loader support (langchain-ai#17788)
This pull request support loading data from TiDB database with Langchain. A simple usage: ``` from langchain_community.document_loaders import TiDBLoader CONNECTION_STRING = "mysql+pymysql://root@127.0.0.1:4000/test" QUERY = "select id, name, description from items;" loader = TiDBLoader( connection_string=CONNECTION_STRING, query=QUERY, page_content_columns=["name", "description"], metadata_columns=["id"], ) documents = loader.load() print(documents) ```
- Loading branch information
1 parent
f60a111
commit 725f6a6
Showing
5 changed files
with
339 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# TiDB\n", | ||
"\n", | ||
"> [TiDB](https://github.com/pingcap/tidb) is an open-source, cloud-native, distributed, MySQL-Compatible database for elastic scale and real-time analytics.\n", | ||
"\n", | ||
"This notebook introduces how to use `TiDBLoader` to load data from TiDB in langchain." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Prerequisites\n", | ||
"\n", | ||
"Before using the `TiDBLoader`, we will install the following dependencies:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install --upgrade --quiet langchain" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Then, we will configure the connection to a TiDB. In this notebook, we will follow the standard connection method provided by TiDB Cloud to establish a secure and efficient database connection." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import getpass\n", | ||
"\n", | ||
"# copy from tidb cloud console,replace it with your own\n", | ||
"tidb_connection_string_template = \"mysql+pymysql://<USER>:<PASSWORD>@<HOST>:4000/<DB>?ssl_ca=/etc/ssl/cert.pem&ssl_verify_cert=true&ssl_verify_identity=true\"\n", | ||
"tidb_password = getpass.getpass(\"Input your TiDB password:\")\n", | ||
"tidb_connection_string = tidb_connection_string_template.replace(\n", | ||
" \"<PASSWORD>\", tidb_password\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load Data from TiDB\n", | ||
"\n", | ||
"Here's a breakdown of some key arguments you can use to customize the behavior of the `TiDBLoader`:\n", | ||
"\n", | ||
"- `query` (str): This is the SQL query to be executed against the TiDB database. The query should select the data you want to load into your `Document` objects. \n", | ||
" For instance, you might use a query like `\"SELECT * FROM my_table\"` to fetch all data from `my_table`.\n", | ||
"\n", | ||
"- `page_content_columns` (Optional[List[str]]): Specifies the list of column names whose values should be included in the `page_content` of each `Document` object. \n", | ||
" If set to `None` (the default), all columns returned by the query are included in `page_content`. This allows you to tailor the content of each document based on specific columns of your data.\n", | ||
"\n", | ||
"- `metadata_columns` (Optional[List[str]]): Specifies the list of column names whose values should be included in the `metadata` of each `Document` object. \n", | ||
" By default, this list is empty, meaning no metadata will be included unless explicitly specified. This is useful for including additional information about each document that doesn't form part of the main content but is still valuable for processing or analysis." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine\n", | ||
"\n", | ||
"# Connect to the database\n", | ||
"engine = create_engine(tidb_connection_string)\n", | ||
"metadata = MetaData()\n", | ||
"table_name = \"test_tidb_loader\"\n", | ||
"\n", | ||
"# Create a table\n", | ||
"test_table = Table(\n", | ||
" table_name,\n", | ||
" metadata,\n", | ||
" Column(\"id\", Integer, primary_key=True),\n", | ||
" Column(\"name\", String(255)),\n", | ||
" Column(\"description\", String(255)),\n", | ||
")\n", | ||
"metadata.create_all(engine)\n", | ||
"\n", | ||
"\n", | ||
"with engine.connect() as connection:\n", | ||
" transaction = connection.begin()\n", | ||
" try:\n", | ||
" connection.execute(\n", | ||
" test_table.insert(),\n", | ||
" [\n", | ||
" {\"name\": \"Item 1\", \"description\": \"Description of Item 1\"},\n", | ||
" {\"name\": \"Item 2\", \"description\": \"Description of Item 2\"},\n", | ||
" {\"name\": \"Item 3\", \"description\": \"Description of Item 3\"},\n", | ||
" ],\n", | ||
" )\n", | ||
" transaction.commit()\n", | ||
" except:\n", | ||
" transaction.rollback()\n", | ||
" raise" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"------------------------------\n", | ||
"content: name: Item 1\n", | ||
"description: Description of Item 1\n", | ||
"metada: {'id': 1}\n", | ||
"------------------------------\n", | ||
"content: name: Item 2\n", | ||
"description: Description of Item 2\n", | ||
"metada: {'id': 2}\n", | ||
"------------------------------\n", | ||
"content: name: Item 3\n", | ||
"description: Description of Item 3\n", | ||
"metada: {'id': 3}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from langchain_community.document_loaders import TiDBLoader\n", | ||
"\n", | ||
"# Setup TiDBLoader to retrieve data\n", | ||
"loader = TiDBLoader(\n", | ||
" connection_string=tidb_connection_string,\n", | ||
" query=f\"SELECT * FROM {table_name};\",\n", | ||
" page_content_columns=[\"name\", \"description\"],\n", | ||
" metadata_columns=[\"id\"],\n", | ||
")\n", | ||
"\n", | ||
"# Load data\n", | ||
"documents = loader.load()\n", | ||
"\n", | ||
"# Display the loaded documents\n", | ||
"for doc in documents:\n", | ||
" print(\"-\" * 30)\n", | ||
" print(f\"content: {doc.page_content}\\nmetada: {doc.metadata}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"test_table.drop(bind=engine)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "langchain", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
libs/community/langchain_community/document_loaders/tidb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from typing import Any, Dict, Iterator, List, Optional | ||
|
||
from langchain_core.documents import Document | ||
|
||
from langchain_community.document_loaders.base import BaseLoader | ||
|
||
|
||
class TiDBLoader(BaseLoader): | ||
"""Load documents from TiDB.""" | ||
|
||
def __init__( | ||
self, | ||
connection_string: str, | ||
query: str, | ||
page_content_columns: Optional[List[str]] = None, | ||
metadata_columns: Optional[List[str]] = None, | ||
engine_args: Optional[Dict[str, Any]] = None, | ||
) -> None: | ||
"""Initialize TiDB document loader. | ||
Args: | ||
connection_string (str): The connection string for the TiDB database, | ||
format: "mysql+pymysql://root@127.0.0.1:4000/test". | ||
query: The query to run in TiDB. | ||
page_content_columns: Optional. Columns written to Document `page_content`, | ||
default(None) to all columns. | ||
metadata_columns: Optional. Columns written to Document `metadata`, | ||
default(None) to no columns. | ||
engine_args: Optional. Additional arguments to pass to sqlalchemy engine. | ||
""" | ||
self.connection_string = connection_string | ||
self.query = query | ||
self.page_content_columns = page_content_columns | ||
self.metadata_columns = metadata_columns if metadata_columns is not None else [] | ||
self.engine_args = engine_args | ||
|
||
def lazy_load(self) -> Iterator[Document]: | ||
"""Lazy load TiDB data into document objects.""" | ||
|
||
from sqlalchemy import create_engine | ||
from sqlalchemy.engine import Engine | ||
from sqlalchemy.sql import text | ||
|
||
# use sqlalchemy to create db connection | ||
engine: Engine = create_engine( | ||
self.connection_string, **(self.engine_args or {}) | ||
) | ||
|
||
# execute query | ||
with engine.connect() as conn: | ||
result = conn.execute(text(self.query)) | ||
|
||
# convert result to Document objects | ||
column_names = list(result.keys()) | ||
for row in result: | ||
# convert row to dict{column:value} | ||
row_data = { | ||
column_names[index]: value for index, value in enumerate(row) | ||
} | ||
page_content = "\n".join( | ||
f"{k}: {v}" | ||
for k, v in row_data.items() | ||
if self.page_content_columns is None | ||
or k in self.page_content_columns | ||
) | ||
metadata = {col: row_data[col] for col in self.metadata_columns} | ||
yield Document(page_content=page_content, metadata=metadata) | ||
|
||
def load(self) -> List[Document]: | ||
"""Load TiDB data into document objects.""" | ||
return list(self.lazy_load()) |
76 changes: 76 additions & 0 deletions
76
libs/community/tests/integration_tests/document_loaders/test_tidb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import os | ||
|
||
import pytest | ||
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine | ||
|
||
from langchain_community.document_loaders import TiDBLoader | ||
|
||
try: | ||
CONNECTION_STRING = os.getenv("TEST_TiDB_CONNECTION_URL", "") | ||
|
||
if CONNECTION_STRING == "": | ||
raise OSError("TEST_TiDB_URL environment variable is not set") | ||
|
||
tidb_available = True | ||
except (OSError, ImportError): | ||
tidb_available = False | ||
|
||
|
||
@pytest.mark.skipif(not tidb_available, reason="tidb is not available") | ||
def test_load_documents() -> None: | ||
"""Test loading documents from TiDB.""" | ||
|
||
# Connect to the database | ||
engine = create_engine(CONNECTION_STRING) | ||
metadata = MetaData() | ||
table_name = "tidb_loader_intergration_test" | ||
|
||
# Create a test table | ||
test_table = Table( | ||
table_name, | ||
metadata, | ||
Column("id", Integer, primary_key=True), | ||
Column("name", String(255)), | ||
Column("description", String(255)), | ||
) | ||
metadata.create_all(engine) | ||
|
||
with engine.connect() as connection: | ||
transaction = connection.begin() | ||
try: | ||
connection.execute( | ||
test_table.insert(), | ||
[ | ||
{"name": "Item 1", "description": "Description of Item 1"}, | ||
{"name": "Item 2", "description": "Description of Item 2"}, | ||
{"name": "Item 3", "description": "Description of Item 3"}, | ||
], | ||
) | ||
transaction.commit() | ||
except: | ||
transaction.rollback() | ||
raise | ||
|
||
loader = TiDBLoader( | ||
connection_string=CONNECTION_STRING, | ||
query=f"SELECT * FROM {table_name};", | ||
page_content_columns=["name", "description"], | ||
metadata_columns=["id"], | ||
) | ||
documents = loader.load() | ||
test_table.drop(bind=engine) | ||
|
||
# check | ||
assert len(documents) == 3 | ||
assert ( | ||
documents[0].page_content == "name: Item 1\ndescription: Description of Item 1" | ||
) | ||
assert documents[0].metadata == {"id": 1} | ||
assert ( | ||
documents[1].page_content == "name: Item 2\ndescription: Description of Item 2" | ||
) | ||
assert documents[1].metadata == {"id": 2} | ||
assert ( | ||
documents[2].page_content == "name: Item 3\ndescription: Description of Item 3" | ||
) | ||
assert documents[2].metadata == {"id": 3} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters