Skip to content

Commit

Permalink
community[patch]: add args_schema to sql_database tools for langGraph…
Browse files Browse the repository at this point in the history
… integration (langchain-ai#18595)

- **Description:** This modification adds pydantic input definition for
sql_database tools. This helps for function calling capability in
LangGraph. Since actions nodes will usually check for the args_schema
attribute on tools, This update should make these tools compatible with
it (only implemented on the InfoSQLDatabaseTool)
  - **Issue:** N/A
  - **Dependencies:** N/A
  - **Twitter handle:** juanfe8881
  • Loading branch information
juanfe88 authored and rahul-trip committed Mar 27, 2024
1 parent 011ec69 commit 9d8cf12
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions libs/community/langchain_community/tools/sql_database/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
If the query is not correct, an error message will be returned.
If an error is returned, rewrite the query, check the query, and try again.
"""
args_schema: Type[BaseModel] = _QuerySQLDataBaseToolInput

def _run(
self,
Expand Down Expand Up @@ -77,11 +78,16 @@ def _run(
)


class _ListSQLDataBaseToolInput(BaseModel):
tool_input: str = Field(..., description="An empty string")


class ListSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
"""Tool for getting tables names."""

name: str = "sql_db_list_tables"
description: str = "Input is an empty string, output is a comma separated list of tables in the database."
args_schema: Type[BaseModel] = _ListSQLDataBaseToolInput

def _run(
self,
Expand All @@ -92,6 +98,10 @@ def _run(
return ", ".join(self.db.get_usable_table_names())


class _QuerySQLCheckerToolInput(BaseModel):
query: str = Field(..., description="A detailed and SQL query to be checked.")


class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
"""Use an LLM to check if a query is correct.
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
Expand All @@ -104,6 +114,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
Use this tool to double check if your query is correct before executing it.
Always use this tool before executing a query with sql_db_query!
"""
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput

@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Expand Down

0 comments on commit 9d8cf12

Please sign in to comment.