langchain: upgrade mypy #19163

Merged: 9 commits on Mar 15, 2024
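Every change in the diff below follows one pattern: rather than reworking signatures, each error surfaced by the mypy upgrade is silenced with an error-code-scoped suppression. A minimal illustration of how these scoped comments behave (the function here is invented for the example, not from langchain):

from typing import List


def greet(names: List[str]) -> None:
    for name in names:
        print(f"hello, {name}")


pair = ("ada", "grace")  # a tuple, not the List[str] the signature declares
# A bare "# type: ignore" would hide every error on this line; the
# bracketed code suppresses only the arg-type complaint, so any new,
# unrelated error on the same line would still be reported.
greet(pair)  # type: ignore[arg-type]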
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/chat/base.py
@@ -97,7 +97,7 @@ def create_prompt(
]
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
- return ChatPromptTemplate(input_variables=input_variables, messages=messages)
+ return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]

@classmethod
def from_llm_and_tools(
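A plausible reduction of why these `messages=messages` calls need `[arg-type]` (the classes and signature below are simplified stand-ins, not langchain's real definitions): the list's inferred element type is a union of the concrete message classes, and that union is not assignable to a parameter annotated with a narrower element type.

from typing import List


class SystemMessagePromptTemplate: ...
class MessagesPlaceholder: ...


# Assumed, narrower-than-reality parameter annotation, for illustration only.
def chat_prompt_template(messages: List[SystemMessagePromptTemplate]) -> None: ...


messages = [SystemMessagePromptTemplate(), MessagesPlaceholder()]
# inferred: List[Union[SystemMessagePromptTemplate, MessagesPlaceholder]],
# which mypy rejects for List[SystemMessagePromptTemplate]  [arg-type]
chat_prompt_template(messages=messages)  # type: ignore[arg-type]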
@@ -88,7 +88,7 @@ def create_prompt(
HumanMessagePromptTemplate.from_template(final_prompt),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
- return ChatPromptTemplate(input_variables=input_variables, messages=messages)
+ return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]

def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/load_tools.py
@@ -406,7 +406,7 @@ def _get_eleven_labs_text2speech(**kwargs: Any) -> BaseTool:


def _get_memorize(llm: BaseLanguageModel, **kwargs: Any) -> BaseTool:
- return Memorize(llm=llm)
+ return Memorize(llm=llm) # type: ignore[arg-type]


def _get_google_cloud_texttospeech(**kwargs: Any) -> BaseTool:
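The `Memorize(llm=llm)` suppression has the same shape: the tool presumably declares its `llm` field with a more specific type than the `BaseLanguageModel` the loader receives, and the downcast is suppressed rather than proven. A hypothetical sketch (the field type here is an assumption, not `Memorize`'s real annotation):

from typing import Any


class BaseLanguageModel: ...


class FineTunableLLM(BaseLanguageModel): ...


class Memorize:
    # assumed: the field wants a capability the base class cannot promise
    def __init__(self, llm: FineTunableLLM) -> None:
        self.llm = llm


def _get_memorize(llm: BaseLanguageModel, **kwargs: Any) -> Memorize:
    # llm is only known statically as BaseLanguageModel  [arg-type]
    return Memorize(llm=llm)  # type: ignore[arg-type]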
@@ -201,7 +201,7 @@ def create_prompt(
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
- return ChatPromptTemplate(messages=messages)
+ return ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

@classmethod
def from_llm_and_tools(
@@ -220,7 +220,7 @@ def from_llm_and_tools(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
- return cls(
+ return cls( # type: ignore[call-arg]
llm=llm,
prompt=prompt,
tools=tools,
@@ -279,7 +279,7 @@ def create_prompt(
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
- return ChatPromptTemplate(messages=messages)
+ return ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

@classmethod
def from_llm_and_tools(
@@ -298,7 +298,7 @@ def from_llm_and_tools(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
- return cls(
+ return cls( # type: ignore[call-arg]
llm=llm,
prompt=prompt,
tools=tools,
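The `return cls( # type: ignore[call-arg]` lines address a different error code: `call-arg` fires when mypy believes a call is missing, or passing an unknown, argument. In langchain this typically comes from pydantic-style models whose constructors are more permissive at runtime than their declared signatures. The sketch below fakes that gap with a decorator so it runs standalone; it is an analogy, not the library's mechanism:

from typing import Any, TypeVar

C = TypeVar("C")


def inject_default_tools(cls: C) -> C:
    """Runtime patch making `tools` optional; mypy still checks calls
    against the original, stricter __init__."""
    original = cls.__init__  # type: ignore[misc]

    def patched(self: Any, *args: Any, tools: Any = None, **kw: Any) -> None:
        original(self, *args, tools=tools if tools is not None else [], **kw)

    cls.__init__ = patched  # type: ignore[misc]
    return cls


@inject_default_tools
class Agent:
    def __init__(self, llm: object, tools: list) -> None:
        self.llm, self.tools = llm, tools


# Works at runtime (the patch supplies tools=[]), but mypy checks the
# declared signature and reports a missing argument  [call-arg]
agent = Agent(llm=object())  # type: ignore[call-arg]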
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/structured_chat/base.py
@@ -103,7 +103,7 @@ def create_prompt(
*_memory_prompts,
HumanMessagePromptTemplate.from_template(human_message_template),
]
- return ChatPromptTemplate(input_variables=input_variables, messages=messages)
+ return ChatPromptTemplate(input_variables=input_variables, messages=messages) # type: ignore[arg-type]

@classmethod
def from_llm_and_tools(
5 changes: 3 additions & 2 deletions libs/langchain/langchain/chains/graph_qa/cypher.py
@@ -199,10 +199,11 @@ def from_llm(
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
)

- qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)
+ qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs) # type: ignore[arg-type]

cypher_generation_chain = LLMChain(
- llm=cypher_llm or llm, **use_cypher_llm_kwargs
+ llm=cypher_llm or llm, # type: ignore[arg-type]
+ **use_cypher_llm_kwargs, # type: ignore[arg-type]
)

if exclude_types and include_types:
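Worth noting in this hunk: once the `LLMChain(...)` call was split across lines, the ignores moved with the arguments, because mypy attaches each error to a specific physical line and the suppression comment must sit on that same line. A schematic version (the function and variables are invented for the example):

from typing import Any, Dict, Optional


def make_chain(llm: str, **kwargs: Any) -> str:
    return f"chain({llm!r}, {kwargs!r})"


maybe_llm: Optional[str] = "stub-llm"
extra_kwargs: Dict[str, Any] = {"verbose": True}

# One trailing comment on the first line would not cover an error
# reported on a later line, so each flagged line carries its own ignore.
chain = make_chain(
    llm=maybe_llm,  # type: ignore[arg-type]
    **extra_kwargs,
)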
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/graph_qa/neptune_sparql.py
@@ -135,7 +135,7 @@ def from_llm(
)
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)

- return cls(
+ return cls( # type: ignore[call-arg]
qa_chain=qa_chain,
sparql_generation_chain=sparql_generation_chain,
examples=examples,
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/llm_checker/base.py
@@ -54,7 +54,7 @@ def _load_question_to_checked_assertions_chain(
revised_answer_chain,
]
question_to_checked_assertions_chain = SequentialChain(
- chains=chains,
+ chains=chains, # type: ignore[arg-type]
input_variables=["question"],
output_variables=["revised_statement"],
verbose=True,
46 changes: 24 additions & 22 deletions libs/langchain/langchain/chains/loading.py
@@ -69,7 +69,9 @@ def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedde
else:
raise ValueError("`embeddings` must be present.")
return HypotheticalDocumentEmbedder(
- llm_chain=llm_chain, base_embeddings=embeddings, **config
+ llm_chain=llm_chain, # type: ignore[arg-type]
+ base_embeddings=embeddings,
+ **config, # type: ignore[arg-type]
)


@@ -125,7 +127,7 @@ def _load_map_reduce_documents_chain(

return MapReduceDocumentsChain(
llm_chain=llm_chain,
- reduce_documents_chain=reduce_documents_chain,
+ reduce_documents_chain=reduce_documents_chain, # type: ignore[arg-type]
**config,
)

@@ -207,7 +209,7 @@ def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any:
elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path"))
if llm_chain:
- return LLMBashChain(llm_chain=llm_chain, prompt=prompt, **config)
+ return LLMBashChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
else:
return LLMBashChain(llm=llm, prompt=prompt, **config)

@@ -250,10 +252,10 @@ def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
revised_answer_prompt = load_prompt(config.pop("revised_answer_prompt_path"))
return LLMCheckerChain(
llm=llm,
- create_draft_answer_prompt=create_draft_answer_prompt,
- list_assertions_prompt=list_assertions_prompt,
- check_assertions_prompt=check_assertions_prompt,
- revised_answer_prompt=revised_answer_prompt,
+ create_draft_answer_prompt=create_draft_answer_prompt, # type: ignore[arg-type]
+ list_assertions_prompt=list_assertions_prompt, # type: ignore[arg-type]
+ check_assertions_prompt=check_assertions_prompt, # type: ignore[arg-type]
+ revised_answer_prompt=revised_answer_prompt, # type: ignore[arg-type]
**config,
)

@@ -281,7 +283,7 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
elif "prompt_path" in config:
prompt = load_prompt(config.pop("prompt_path"))
if llm_chain:
- return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config)
+ return LLMMathChain(llm_chain=llm_chain, prompt=prompt, **config) # type: ignore[arg-type]
else:
return LLMMathChain(llm=llm, prompt=prompt, **config)

@@ -296,7 +298,7 @@ def _load_map_rerank_documents_chain(
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
- return MapRerankDocumentsChain(llm_chain=llm_chain, **config)
+ return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]


def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
@@ -309,7 +311,7 @@ def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
llm_chain = load_chain(config.pop("llm_chain_path"))
else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
- return PALChain(llm_chain=llm_chain, **config)
+ return PALChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]


def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain:
@@ -337,8 +339,8 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocuments
elif "document_prompt_path" in config:
document_prompt = load_prompt(config.pop("document_prompt_path"))
return RefineDocumentsChain(
- initial_llm_chain=initial_llm_chain,
- refine_llm_chain=refine_llm_chain,
+ initial_llm_chain=initial_llm_chain, # type: ignore[arg-type]
+ refine_llm_chain=refine_llm_chain, # type: ignore[arg-type]
document_prompt=document_prompt,
**config,
)
@@ -355,7 +357,7 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
"One of `combine_documents_chain` or "
"`combine_documents_chain_path` must be present."
)
- return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config)
+ return QAWithSourcesChain(combine_documents_chain=combine_documents_chain, **config) # type: ignore[arg-type]


def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
@@ -368,7 +370,7 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain")
chain = load_chain_from_config(llm_chain_config)
- return SQLDatabaseChain(llm_chain=chain, database=database, **config)
+ return SQLDatabaseChain(llm_chain=chain, database=database, **config) # type: ignore[arg-type]
if "llm" in config:
llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config)
@@ -403,7 +405,7 @@ def _load_vector_db_qa_with_sources_chain(
"`combine_documents_chain_path` must be present."
)
return VectorDBQAWithSourcesChain(
- combine_documents_chain=combine_documents_chain,
+ combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
vectorstore=vectorstore,
**config,
)
@@ -425,7 +427,7 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
"`combine_documents_chain_path` must be present."
)
return RetrievalQA(
- combine_documents_chain=combine_documents_chain,
+ combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
retriever=retriever,
**config,
)
@@ -449,7 +451,7 @@ def _load_retrieval_qa_with_sources_chain(
"`combine_documents_chain_path` must be present."
)
return RetrievalQAWithSourcesChain(
- combine_documents_chain=combine_documents_chain,
+ combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
retriever=retriever,
**config,
)
@@ -471,7 +473,7 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
"`combine_documents_chain_path` must be present."
)
return VectorDBQA(
- combine_documents_chain=combine_documents_chain,
+ combine_documents_chain=combine_documents_chain, # type: ignore[arg-type]
vectorstore=vectorstore,
**config,
)
@@ -495,8 +497,8 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:

return GraphCypherQAChain(
graph=graph,
- cypher_generation_chain=cypher_generation_chain,
- qa_chain=qa_chain,
+ cypher_generation_chain=cypher_generation_chain, # type: ignore[arg-type]
+ qa_chain=qa_chain, # type: ignore[arg-type]
**config,
)

@@ -525,8 +527,8 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
else:
raise ValueError("`requests_wrapper` must be present.")
return APIChain(
- api_request_chain=api_request_chain,
- api_answer_chain=api_answer_chain,
+ api_request_chain=api_request_chain, # type: ignore[arg-type]
+ api_answer_chain=api_answer_chain, # type: ignore[arg-type]
requests_wrapper=requests_wrapper,
**config,
)
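The cluster of ignores in `loading.py` shares one root cause: these `_load_*` helpers deserialize from an untyped `config: dict`, and the chain loader can only promise a generic `Chain`, while the fields being filled want specific subclasses. The ignores record a downcast at the deserialization boundary instead of adding runtime checks. A simplified model of the situation (stand-in classes; the real signatures differ):

from typing import Any


class Chain: ...


class LLMChain(Chain): ...


def load_chain_from_config(config: dict) -> Chain:
    # stand-in loader: statically it can only promise a Chain
    return LLMChain()


class MapRerankDocumentsChain:
    def __init__(self, llm_chain: LLMChain, **kwargs: Any) -> None:
        self.llm_chain = llm_chain


config: dict = {"llm_chain": {}}
loaded = load_chain_from_config(config.pop("llm_chain"))
# `loaded` is statically a Chain, but the field wants LLMChain  [arg-type]
chain = MapRerankDocumentsChain(llm_chain=loaded)  # type: ignore[arg-type]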
@@ -95,7 +95,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
)
),
]
- prompt = ChatPromptTemplate(messages=messages)
+ prompt = ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

chain = LLMChain(
llm=llm,
@@ -82,7 +82,7 @@ def create_qa_with_structure_chain(
HumanMessagePromptTemplate.from_template("Question: {question}"),
HumanMessage(content="Tips: Make sure to answer in the correct format"),
]
- prompt = prompt or ChatPromptTemplate(messages=messages)
+ prompt = prompt or ChatPromptTemplate(messages=messages) # type: ignore[arg-type, call-arg]

chain = LLMChain(
llm=llm,
22 changes: 11 additions & 11 deletions libs/langchain/langchain/chains/qa_with_sources/loading.py
@@ -59,12 +59,12 @@ def _load_stuff_chain(
verbose: Optional[bool] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
- llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+ llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose) # type: ignore[arg-type]
return StuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
document_prompt=document_prompt,
- verbose=verbose,
+ verbose=verbose, # type: ignore[arg-type]
**kwargs,
)

@@ -83,14 +83,14 @@ def _load_map_reduce_chain(
token_max: int = 3000,
**kwargs: Any,
) -> MapReduceDocumentsChain:
- map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
+ map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
_reduce_llm = reduce_llm or llm
- reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
+ reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) # type: ignore[arg-type]
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
document_prompt=document_prompt,
- verbose=verbose,
+ verbose=verbose, # type: ignore[arg-type]
)
if collapse_prompt is None:
collapse_chain = None
@@ -105,7 +105,7 @@ def _load_map_reduce_chain(
llm_chain=LLMChain(
llm=_collapse_llm,
prompt=collapse_prompt,
- verbose=verbose,
+ verbose=verbose, # type: ignore[arg-type]
),
document_variable_name=combine_document_variable_name,
document_prompt=document_prompt,
@@ -114,13 +114,13 @@ def _load_map_reduce_chain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
token_max=token_max,
- verbose=verbose,
+ verbose=verbose, # type: ignore[arg-type]
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name,
- verbose=verbose,
+ verbose=verbose, # type: ignore[arg-type]
**kwargs,
)

@@ -136,16 +136,16 @@ def _load_refine_chain(
verbose: Optional[bool] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
- initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
+ initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose) # type: ignore[arg-type]
_refine_llm = refine_llm or llm
- refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)
+ refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose) # type: ignore[arg-type]
return RefineDocumentsChain(
initial_llm_chain=initial_chain,
refine_llm_chain=refine_chain,
document_variable_name=document_variable_name,
initial_response_name=initial_response_name,
document_prompt=document_prompt,
- verbose=verbose,
+ verbose=verbose, # type: ignore[arg-type]
**kwargs,
)

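Most of the ignores in this last file annotate `verbose=verbose`, where the helper accepts `verbose: Optional[bool] = None` but the downstream constructors are (presumably) typed as plain `bool`. At runtime `None` is meaningful, deferring to a global default, so the mismatch is suppressed rather than narrowed. A reduced sketch (the `bool` annotation on the constructor is my assumption, not the real field type):

from typing import Optional


class StuffDocumentsChain:
    # assumed plain-bool parameter; langchain resolves None elsewhere
    def __init__(self, verbose: bool = False) -> None:
        self.verbose = verbose


def _load_stuff_chain(verbose: Optional[bool] = None) -> StuffDocumentsChain:
    # Optional[bool] is not assignable to bool  [arg-type]; passing None
    # through is intentional, so the error is suppressed in place.
    return StuffDocumentsChain(verbose=verbose)  # type: ignore[arg-type]


print(_load_stuff_chain(verbose=True).verbose)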