Skip to content

Commit

Permalink
community: Add you.com tool, add async to retriever, add async testin…
Browse files Browse the repository at this point in the history
…g, add You tool doc (langchain-ai#18032)

- **Description:** finishes adding the you.com functionality including:
    - add async functions to utility and retriever
    - add the You.com Tool
    - add async testing for utility, retriever, and tool
    - add a tool integration notebook page
- **Dependencies:** any dependencies required for this change
- **Twitter handle:** @scottnath
  • Loading branch information
scottnath authored and gkorland committed Mar 30, 2024
1 parent 16a234f commit 3505420
Show file tree
Hide file tree
Showing 11 changed files with 567 additions and 40 deletions.
265 changes: 265 additions & 0 deletions docs/docs/integrations/tools/you.ipynb

Large diffs are not rendered by default.

17 changes: 16 additions & 1 deletion libs/community/langchain_community/retrievers/you.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any, List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

Expand All @@ -21,3 +24,15 @@ def _get_relevant_documents(
**kwargs: Any,
) -> List[Document]:
return self.results(query, run_manager=run_manager.get_child(), **kwargs)

async def _aget_relevant_documents(
    self,
    query: str,
    *,
    run_manager: AsyncCallbackManagerForRetrieverRun,
    **kwargs: Any,
) -> List[Document]:
    """Asynchronously fetch the documents matching ``query``.

    Args:
        query: Search text handed to ``self.results_async``.
        run_manager: Callback manager supplied by the caller. It is accepted
            to satisfy the method signature but is intentionally not
            forwarded (see the note in the body).
        **kwargs: Extra keyword arguments passed through to
            ``self.results_async``.

    Returns:
        The list of documents returned by ``self.results_async``.
    """
    # Do not forward ``run_manager``: ``results_async`` hands its keyword
    # arguments to ``raw_results_async``, which merges them into the HTTP
    # query parameters. A callback-manager object would therefore be sent as
    # a request parameter instead of being used for callbacks.
    return await self.results_async(query, **kwargs)
9 changes: 9 additions & 0 deletions libs/community/langchain_community/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,12 @@ def _import_yahoo_finance_news() -> Any:
return YahooFinanceNewsTool


def _import_you_tool() -> Any:
    """Import ``YouSearchTool`` on demand and return the class object."""
    # The import lives inside the function body, so it only runs when this
    # helper is called.
    from langchain_community.tools.you.tool import YouSearchTool as tool_cls

    return tool_cls


def _import_youtube_search() -> Any:
from langchain_community.tools.youtube.search import YouTubeSearchTool

Expand Down Expand Up @@ -1055,6 +1061,8 @@ def __getattr__(name: str) -> Any:
return _import_wolfram_alpha_tool()
elif name == "YahooFinanceNewsTool":
return _import_yahoo_finance_news()
elif name == "YouSearchTool":
return _import_you_tool()
elif name == "YouTubeSearchTool":
return _import_youtube_search()
elif name == "ZapierNLAListActions":
Expand Down Expand Up @@ -1192,6 +1200,7 @@ def __getattr__(name: str) -> Any:
"WolframAlphaQueryRun",
"WriteFileTool",
"YahooFinanceNewsTool",
"YouSearchTool",
"YouTubeSearchTool",
"ZapierNLAListActions",
"ZapierNLARunAction",
Expand Down
8 changes: 8 additions & 0 deletions libs/community/langchain_community/tools/you/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""You.com API toolkit."""


from langchain_community.tools.you.tool import YouSearchTool

__all__ = [
"YouSearchTool",
]
43 changes: 43 additions & 0 deletions libs/community/langchain_community/tools/you/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import List, Optional, Type

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool

from langchain_community.utilities.you import YouSearchAPIWrapper


class YouInput(BaseModel):
    # Input model with a single required field: the search query string.
    query: str = Field(
        description="should be a search query",
    )


class YouSearchTool(BaseTool):
    """Tool that searches the you.com API."""

    # Identifier under which the tool is exposed.
    name = "you_search"
    # Adjacent string literals are concatenated with no separator, so the
    # first fragment must end with a space to keep "and" and "up" apart.
    description = (
        "The YOU APIs make LLMs and search experiences more factual and "
        "up to date with realtime web data."
    )
    # Schema describing the tool input (a single ``query`` string).
    args_schema: Type[BaseModel] = YouInput
    # Wrapper that performs the actual you.com requests; a default instance
    # is built when none is supplied.
    api_wrapper: YouSearchAPIWrapper = Field(default_factory=YouSearchAPIWrapper)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> List[Document]:
        """Run a you.com search synchronously.

        Args:
            query: Search text passed to ``api_wrapper.results``.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The documents returned by ``api_wrapper.results``.
        """
        return self.api_wrapper.results(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> List[Document]:
        """Run a you.com search asynchronously.

        Args:
            query: Search text passed to ``api_wrapper.results_async``.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The documents returned by ``api_wrapper.results_async``.
        """
        return await self.api_wrapper.results_async(query)
79 changes: 41 additions & 38 deletions libs/community/langchain_community/utilities/you.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
In order to set this up, follow instructions at:
"""
import json
from typing import Any, Dict, List, Literal, Optional

import aiohttp
Expand Down Expand Up @@ -113,16 +112,16 @@ def _parse_results(self, raw_search_results: Dict) -> List[Document]:

docs = []
for hit in raw_search_results["hits"]:
n_snippets_per_hit = self.n_snippets_per_hit or len(hit["snippets"])
for snippet in hit["snippets"][:n_snippets_per_hit]:
n_snippets_per_hit = self.n_snippets_per_hit or len(hit.get("snippets"))
for snippet in hit.get("snippets")[:n_snippets_per_hit]:
docs.append(
Document(
page_content=snippet,
metadata={
"url": hit["url"],
"thumbnail_url": hit["thumbnail_url"],
"title": hit["title"],
"description": hit["description"],
"url": hit.get("url"),
"thumbnail_url": hit.get("thumbnail_url"),
"title": hit.get("title"),
"description": hit.get("description"),
},
)
)
Expand Down Expand Up @@ -188,43 +187,47 @@ def results(
async def raw_results_async(
self,
query: str,
num_web_results: Optional[int] = 5,
safesearch: Optional[str] = "moderate",
country: Optional[str] = "US",
**kwargs: Any,
) -> Dict:
"""Get results from the you.com Search API asynchronously."""

# Function to perform the API call
async def fetch() -> str:
params = {
"query": query,
"num_web_results": num_web_results,
"safesearch": safesearch,
"country": country,
}
async with aiohttp.ClientSession() as session:
async with session.post(f"{YOU_API_URL}/search", json=params) as res:
if res.status == 200:
data = await res.text()
return data
else:
raise Exception(f"Error {res.status}: {res.reason}")

results_json_str = await fetch()
return json.loads(results_json_str)
headers = {"X-API-Key": self.ydc_api_key or ""}
params = {
"query": query,
"num_web_results": self.num_web_results,
"safesearch": self.safesearch,
"country": self.country,
**kwargs,
}
params = {k: v for k, v in params.items() if v is not None}
# news endpoint expects `q` instead of `query`
if self.endpoint_type == "news":
params["q"] = params["query"]
del params["query"]

# @todo deprecate `snippet`, not part of API
if self.endpoint_type == "snippet":
self.endpoint_type = "search"

async with aiohttp.ClientSession() as session:
async with session.get(
url=f"{YOU_API_URL}/{self.endpoint_type}",
params=params,
headers=headers,
) as res:
if res.status == 200:
results = await res.json()
return results
else:
raise Exception(f"Error {res.status}: {res.reason}")

async def results_async(
self,
query: str,
num_web_results: Optional[int] = 5,
safesearch: Optional[str] = "moderate",
country: Optional[str] = "US",
**kwargs: Any,
) -> List[Document]:
results_json = await self.raw_results_async(
query=query,
num_web_results=num_web_results,
safesearch=safesearch,
country=country,
raw_search_results_async = await self.raw_results_async(
query,
**{key: value for key, value in kwargs.items() if value is not None},
)

return self._parse_results(results_json["results"])
return self._parse_results(raw_search_results_async)
39 changes: 39 additions & 0 deletions libs/community/tests/unit_tests/retrievers/test_you.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from unittest.mock import AsyncMock, patch

import pytest
import responses

from langchain_community.retrievers.you import YouRetriever
Expand Down Expand Up @@ -70,3 +73,39 @@ def test_invoke_news(self) -> None:
results = you_wrapper.results(query)
expected_result = NEWS_RESPONSE_PARSED
assert results == expected_result

@pytest.mark.asyncio
async def test_aget_relevant_documents(self) -> None:
    """Check ``aget_relevant_documents`` against a mocked HTTP 200 reply."""
    retriever = YouRetriever(ydc_api_key="test_api_key")

    # Stand-in for the aiohttp response. Entering it as an async context
    # manager yields the same object, and its ``json()`` resolves to the
    # raw fixture payload.
    fake_response = AsyncMock()
    fake_response.status = 200
    fake_response.json = AsyncMock(return_value=MOCK_RESPONSE_RAW)
    fake_response.__aenter__.return_value = fake_response
    fake_response.__aexit__.return_value = None

    # Replace ``aiohttp.ClientSession.get`` so no network call is made.
    with patch("aiohttp.ClientSession.get", return_value=fake_response):
        documents = await retriever.aget_relevant_documents("test query")

    assert documents == MOCK_PARSED_OUTPUT

@pytest.mark.asyncio
async def test_ainvoke(self) -> None:
    """Check ``ainvoke`` against a mocked HTTP 200 reply."""
    retriever = YouRetriever(ydc_api_key="test_api_key")

    # Stand-in for the aiohttp response. Entering it as an async context
    # manager yields the same object, and its ``json()`` resolves to the
    # raw fixture payload.
    fake_response = AsyncMock()
    fake_response.status = 200
    fake_response.json = AsyncMock(return_value=MOCK_RESPONSE_RAW)
    fake_response.__aenter__.return_value = fake_response
    fake_response.__aexit__.return_value = None

    # Replace ``aiohttp.ClientSession.get`` so no network call is made.
    with patch("aiohttp.ClientSession.get", return_value=fake_response):
        documents = await retriever.ainvoke("test query")

    assert documents == MOCK_PARSED_OUTPUT
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/tools/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
"WolframAlphaQueryRun",
"WriteFileTool",
"YahooFinanceNewsTool",
"YouSearchTool",
"YouTubeSearchTool",
"ZapierNLAListActions",
"ZapierNLARunAction",
Expand Down
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/tools/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
"WolframAlphaQueryRun",
"WriteFileTool",
"YahooFinanceNewsTool",
"YouSearchTool",
"YouTubeSearchTool",
"ZapierNLAListActions",
"ZapierNLARunAction",
Expand Down
87 changes: 87 additions & 0 deletions libs/community/tests/unit_tests/tools/test_you.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from unittest.mock import AsyncMock, patch

import pytest
import responses

from langchain_community.tools.you import YouSearchTool
from langchain_community.utilities.you import YouSearchAPIWrapper

from ..utilities.test_you import (
LIMITED_PARSED_OUTPUT,
MOCK_PARSED_OUTPUT,
MOCK_RESPONSE_RAW,
NEWS_RESPONSE_PARSED,
NEWS_RESPONSE_RAW,
TEST_ENDPOINT,
)


class TestYouSearchTool:
    """Unit tests for ``YouSearchTool`` using mocked HTTP responses."""

    @staticmethod
    def _mock_get(endpoint: str, payload: object) -> None:
        """Register a 200 JSON reply for GET ``{TEST_ENDPOINT}/{endpoint}``."""
        responses.add(
            responses.GET,
            f"{TEST_ENDPOINT}/{endpoint}",
            json=payload,
            status=200,
        )

    @staticmethod
    def _make_tool(**wrapper_kwargs: object) -> YouSearchTool:
        """Build a tool whose API wrapper uses the key ``"test"``."""
        wrapper = YouSearchAPIWrapper(ydc_api_key="test", **wrapper_kwargs)
        return YouSearchTool(api_wrapper=wrapper)

    @responses.activate
    def test_invoke(self) -> None:
        """A default wrapper returns every parsed document."""
        self._mock_get("search", MOCK_RESPONSE_RAW)
        tool = self._make_tool()
        assert tool.invoke("Test query text") == MOCK_PARSED_OUTPUT

    @responses.activate
    def test_invoke_max_docs(self) -> None:
        """``k=2`` yields the first two parsed documents."""
        self._mock_get("search", MOCK_RESPONSE_RAW)
        tool = self._make_tool(k=2)
        expected = [MOCK_PARSED_OUTPUT[0], MOCK_PARSED_OUTPUT[1]]
        assert tool.invoke("Test query text") == expected

    @responses.activate
    def test_invoke_limit_snippets(self) -> None:
        """``n_snippets_per_hit=1`` yields the limited fixture output."""
        self._mock_get("search", MOCK_RESPONSE_RAW)
        tool = self._make_tool(n_snippets_per_hit=1)
        assert tool.invoke("Test query text") == LIMITED_PARSED_OUTPUT

    @responses.activate
    def test_invoke_news(self) -> None:
        """``endpoint_type="news"`` queries the news endpoint."""
        self._mock_get("news", NEWS_RESPONSE_RAW)
        tool = self._make_tool(endpoint_type="news")
        assert tool.invoke("Test news text") == NEWS_RESPONSE_PARSED

    @pytest.mark.asyncio
    async def test_ainvoke(self) -> None:
        """``ainvoke`` parses a mocked aiohttp 200 reply."""
        tool = self._make_tool()

        # Stand-in for the aiohttp response. Entering it as an async context
        # manager yields the same object, and its ``json()`` resolves to the
        # raw fixture payload.
        fake_response = AsyncMock()
        fake_response.status = 200
        fake_response.json = AsyncMock(return_value=MOCK_RESPONSE_RAW)
        fake_response.__aenter__.return_value = fake_response
        fake_response.__aexit__.return_value = None

        # Replace ``aiohttp.ClientSession.get`` so no network call is made.
        with patch("aiohttp.ClientSession.get", return_value=fake_response):
            documents = await tool.ainvoke("test query")

        assert documents == MOCK_PARSED_OUTPUT

0 comments on commit 3505420

Please sign in to comment.