Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Add you.com tool, add async to retriever, add async testing, add You tool doc #18032

Merged
merged 14 commits into from
Mar 3, 2024
265 changes: 265 additions & 0 deletions docs/docs/integrations/tools/you.ipynb

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions libs/community/langchain_community/retrievers/you.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Any, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_community.utilities import YouSearchAPIWrapper


class YouRetriever(BaseRetriever, YouSearchAPIWrapper):
class YouRetriever(BaseRetriever, YouSearchAPIWrapper): # type: ignore[valid-type,misc]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the type hint to ignore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hwchase17 development artifact. Removed

thanks

"""`You` retriever that uses You.com's search API.
It wraps results() to get_relevant_documents
It uses all YouSearchAPIWrapper arguments without any change.
Expand All @@ -21,3 +24,15 @@ def _get_relevant_documents(
**kwargs: Any,
) -> List[Document]:
return self.results(query, run_manager=run_manager.get_child(), **kwargs)

async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
results = await self.results_async(
query, run_manager=run_manager.get_child(), **kwargs
)
return results
9 changes: 9 additions & 0 deletions libs/community/langchain_community/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,12 @@ def _import_yahoo_finance_news() -> Any:
return YahooFinanceNewsTool


def _import_you_tool() -> Any:
from langchain_community.tools.you.tool import YouSearchTool

return YouSearchTool


def _import_youtube_search() -> Any:
from langchain_community.tools.youtube.search import YouTubeSearchTool

Expand Down Expand Up @@ -1047,6 +1053,8 @@ def __getattr__(name: str) -> Any:
return _import_wolfram_alpha_tool()
elif name == "YahooFinanceNewsTool":
return _import_yahoo_finance_news()
elif name == "YouSearchTool":
return _import_you_tool()
elif name == "YouTubeSearchTool":
return _import_youtube_search()
elif name == "ZapierNLAListActions":
Expand Down Expand Up @@ -1183,6 +1191,7 @@ def __getattr__(name: str) -> Any:
"WolframAlphaQueryRun",
"WriteFileTool",
"YahooFinanceNewsTool",
"YouSearchTool",
"YouTubeSearchTool",
"ZapierNLAListActions",
"ZapierNLARunAction",
Expand Down
8 changes: 8 additions & 0 deletions libs/community/langchain_community/tools/you/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""You.com API toolkit."""


from langchain_community.tools.you.tool import YouSearchTool

__all__ = [
"YouSearchTool",
]
43 changes: 43 additions & 0 deletions libs/community/langchain_community/tools/you/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import List, Optional, Type

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool

from langchain_community.utilities.you import YouSearchAPIWrapper


class YouInput(BaseModel):
query: str = Field(description="should be a search query")


class YouSearchTool(BaseTool):
"""Tool that searches the you.com API"""

name = "you_search"
description = (
"The YOU APIs make LLMs and search experiences more factual and"
"up to date with realtime web data."
)
args_schema: Type[BaseModel] = YouInput
api_wrapper: YouSearchAPIWrapper = Field(default_factory=YouSearchAPIWrapper)

def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> List[Document]:
"""Use the you.com tool."""
return self.api_wrapper.results(query)

async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> List[Document]:
"""Use the you.com tool asynchronously."""
return await self.api_wrapper.results_async(query)
79 changes: 41 additions & 38 deletions libs/community/langchain_community/utilities/you.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

In order to set this up, follow instructions at:
"""
import json
from typing import Any, Dict, List, Literal, Optional

import aiohttp
Expand Down Expand Up @@ -113,16 +112,16 @@ def _parse_results(self, raw_search_results: Dict) -> List[Document]:

docs = []
for hit in raw_search_results["hits"]:
n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"])
for snippet in hit["snippets"][:n_snippets_per_hit]:
n_snippets_per_hit = self.n_snippets_per_hit or len(hit.get("snippets"))
for snippet in hit.get("snippets")[:n_snippets_per_hit]:
docs.append(
Document(
page_content=snippet,
metadata={
"url": hit["url"],
"thumbnail_url": hit["thumbnail_url"],
"title": hit["title"],
"description": hit["description"],
"url": hit.get("url"),
"thumbnail_url": hit.get("thumbnail_url"),
"title": hit.get("title"),
"description": hit.get("description"),
},
)
)
Expand Down Expand Up @@ -188,43 +187,47 @@ def results(
async def raw_results_async(
self,
query: str,
num_web_results: Optional[int] = 5,
safesearch: Optional[str] = "moderate",
country: Optional[str] = "US",
**kwargs: Any,
) -> Dict:
"""Get results from the you.com Search API asynchronously."""

# Function to perform the API call
async def fetch() -> str:
params = {
"query": query,
"num_web_results": num_web_results,
"safesearch": safesearch,
"country": country,
}
async with aiohttp.ClientSession() as session:
async with session.post(f"{YOU_API_URL}/search", json=params) as res:
if res.status == 200:
data = await res.text()
return data
else:
raise Exception(f"Error {res.status}: {res.reason}")

results_json_str = await fetch()
return json.loads(results_json_str)
headers = {"X-API-Key": self.ydc_api_key or ""}
params = {
"query": query,
"num_web_results": self.num_web_results,
"safesearch": self.safesearch,
"country": self.country,
**kwargs,
}
params = {k: v for k, v in params.items() if v is not None}
# news endpoint expects `q` instead of `query`
if self.endpoint_type == "news":
params["q"] = params["query"]
del params["query"]

# @todo deprecate `snippet`, not part of API
if self.endpoint_type == "snippet":
self.endpoint_type = "search"

async with aiohttp.ClientSession() as session:
async with session.get(
url=f"{YOU_API_URL}/{self.endpoint_type}",
params=params,
headers=headers,
) as res:
if res.status == 200:
results = await res.json()
return results
else:
raise Exception(f"Error {res.status}: {res.reason}")

async def results_async(
self,
query: str,
num_web_results: Optional[int] = 5,
safesearch: Optional[str] = "moderate",
country: Optional[str] = "US",
**kwargs: Any,
) -> List[Document]:
results_json = await self.raw_results_async(
query=query,
num_web_results=num_web_results,
safesearch=safesearch,
country=country,
raw_search_results_async = await self.raw_results_async(
query,
**{key: value for key, value in kwargs.items() if value is not None},
)

return self._parse_results(results_json["results"])
return self._parse_results(raw_search_results_async)
39 changes: 39 additions & 0 deletions libs/community/tests/unit_tests/retrievers/test_you.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from unittest.mock import AsyncMock, patch

import pytest
import responses

from langchain_community.retrievers.you import YouRetriever
Expand Down Expand Up @@ -70,3 +73,39 @@ def test_invoke_news(self) -> None:
results = you_wrapper.results(query)
expected_result = NEWS_RESPONSE_PARSED
assert results == expected_result

@pytest.mark.asyncio
async def test_aget_relevant_documents(self) -> None:
instance = YouRetriever(ydc_api_key="test_api_key")

# Mock response object to simulate aiohttp response
mock_response = AsyncMock()
mock_response.__aenter__.return_value = (
mock_response # Make the context manager return itself
)
mock_response.__aexit__.return_value = None # No value needed for exit
mock_response.status = 200
mock_response.json = AsyncMock(return_value=MOCK_RESPONSE_RAW)

# Patch the aiohttp.ClientSession object
with patch("aiohttp.ClientSession.get", return_value=mock_response):
results = await instance.aget_relevant_documents("test query")
assert results == MOCK_PARSED_OUTPUT

@pytest.mark.asyncio
async def test_ainvoke(self) -> None:
instance = YouRetriever(ydc_api_key="test_api_key")

# Mock response object to simulate aiohttp response
mock_response = AsyncMock()
mock_response.__aenter__.return_value = (
mock_response # Make the context manager return itself
)
mock_response.__aexit__.return_value = None # No value needed for exit
mock_response.status = 200
mock_response.json = AsyncMock(return_value=MOCK_RESPONSE_RAW)

# Patch the aiohttp.ClientSession object
with patch("aiohttp.ClientSession.get", return_value=mock_response):
results = await instance.ainvoke("test query")
assert results == MOCK_PARSED_OUTPUT
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/tools/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
"WolframAlphaQueryRun",
"WriteFileTool",
"YahooFinanceNewsTool",
"YouSearchTool",
"YouTubeSearchTool",
"ZapierNLAListActions",
"ZapierNLARunAction",
Expand Down
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/tools/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
"WolframAlphaQueryRun",
"WriteFileTool",
"YahooFinanceNewsTool",
"YouSearchTool",
"YouTubeSearchTool",
"ZapierNLAListActions",
"ZapierNLARunAction",
Expand Down
87 changes: 87 additions & 0 deletions libs/community/tests/unit_tests/tools/test_you.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from unittest.mock import AsyncMock, patch

import pytest
import responses

from langchain_community.tools.you import YouSearchTool
from langchain_community.utilities.you import YouSearchAPIWrapper

from ..utilities.test_you import (
LIMITED_PARSED_OUTPUT,
MOCK_PARSED_OUTPUT,
MOCK_RESPONSE_RAW,
NEWS_RESPONSE_PARSED,
NEWS_RESPONSE_RAW,
TEST_ENDPOINT,
)


class TestYouSearchTool:
@responses.activate
def test_invoke(self) -> None:
responses.add(
responses.GET, f"{TEST_ENDPOINT}/search", json=MOCK_RESPONSE_RAW, status=200
)
query = "Test query text"
you_tool = YouSearchTool(api_wrapper=YouSearchAPIWrapper(ydc_api_key="test"))
results = you_tool.invoke(query)
expected_result = MOCK_PARSED_OUTPUT
assert results == expected_result

@responses.activate
def test_invoke_max_docs(self) -> None:
responses.add(
responses.GET, f"{TEST_ENDPOINT}/search", json=MOCK_RESPONSE_RAW, status=200
)
query = "Test query text"
you_tool = YouSearchTool(
api_wrapper=YouSearchAPIWrapper(ydc_api_key="test", k=2)
)
results = you_tool.invoke(query)
expected_result = [MOCK_PARSED_OUTPUT[0], MOCK_PARSED_OUTPUT[1]]
assert results == expected_result

@responses.activate
def test_invoke_limit_snippets(self) -> None:
responses.add(
responses.GET, f"{TEST_ENDPOINT}/search", json=MOCK_RESPONSE_RAW, status=200
)
query = "Test query text"
you_tool = YouSearchTool(
api_wrapper=YouSearchAPIWrapper(ydc_api_key="test", n_snippets_per_hit=1)
)
results = you_tool.invoke(query)
expected_result = LIMITED_PARSED_OUTPUT
assert results == expected_result

@responses.activate
def test_invoke_news(self) -> None:
responses.add(
responses.GET, f"{TEST_ENDPOINT}/news", json=NEWS_RESPONSE_RAW, status=200
)

query = "Test news text"
you_tool = YouSearchTool(
api_wrapper=YouSearchAPIWrapper(ydc_api_key="test", endpoint_type="news")
)
results = you_tool.invoke(query)
expected_result = NEWS_RESPONSE_PARSED
assert results == expected_result

@pytest.mark.asyncio
async def test_ainvoke(self) -> None:
you_tool = YouSearchTool(api_wrapper=YouSearchAPIWrapper(ydc_api_key="test"))

# Mock response object to simulate aiohttp response
mock_response = AsyncMock()
mock_response.__aenter__.return_value = (
mock_response # Make the context manager return itself
)
mock_response.__aexit__.return_value = None # No value needed for exit
mock_response.status = 200
mock_response.json = AsyncMock(return_value=MOCK_RESPONSE_RAW)

# Patch the aiohttp.ClientSession object
with patch("aiohttp.ClientSession.get", return_value=mock_response):
results = await you_tool.ainvoke("test query")
assert results == MOCK_PARSED_OUTPUT