Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Community: add args_schema to sql_database tools for langGraph integration #18595

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 11 additions & 0 deletions libs/community/langchain_community/tools/sql_database/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput

def _run(
self,
Expand Down Expand Up @@ -77,11 +78,16 @@ def _run(
)


class _ListSQLDataBaseToolInput(BaseModel):
tool_input: str = Field(..., description="An empty string")


class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""

name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput

def _run(
self,
Expand All @@ -92,6 +98,10 @@ def _run(
return ", ".join(self.db.get_usable_table_names())


class _QuerySQLCheckerToolInput(BaseModel):
query: str = Field(..., description="A detailed and SQL query to be checked.")


class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
Expand All @@ -104,6 +114,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query!
"""
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Expand Down