Skip to content

Commit

Permalink
Evaluate on Version (#18471)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Mar 4, 2024
1 parent 55b69d5 commit 1eec67e
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions libs/langchain/langchain/smith/evaluation/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,20 +968,24 @@ def _run_llm_or_chain(
return result


## Public API


def _prepare_eval_run(
client: Client,
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
project_name: str,
project_metadata: Optional[Dict[str, Any]] = None,
tags: Optional[List[str]] = None,
dataset_version: Optional[Union[str, datetime]] = None,
) -> Tuple[MCF, TracerSession, Dataset, List[Example]]:
wrapped_model = _wrap_in_chain_factory(llm_or_chain_factory, dataset_name)
dataset = client.read_dataset(dataset_name=dataset_name)
examples = list(client.list_examples(dataset_id=dataset.id))
as_of = dataset_version if isinstance(dataset_version, datetime) else None
if isinstance(dataset_version, str):
raise NotImplementedError(
"Selecting dataset_version by tag is not yet supported."
" Please use a datetime object."
)
examples = list(client.list_examples(dataset_id=dataset.id, as_of=as_of))
if not examples:
raise ValueError(f"Dataset {dataset_name} has no example rows.")
modified_at = [ex.modified_at for ex in examples if ex.modified_at]
Expand Down Expand Up @@ -1173,6 +1177,7 @@ def prepare(
concurrency_level: int = 5,
project_metadata: Optional[Dict[str, Any]] = None,
revision_id: Optional[str] = None,
dataset_version: Optional[Union[datetime, str]] = None,
) -> _DatasetRunContainer:
project_name = project_name or name_generation.random_name()
if revision_id:
Expand All @@ -1186,6 +1191,7 @@ def prepare(
project_name,
project_metadata=project_metadata,
tags=tags,
dataset_version=dataset_version,
)
tags = tags or []
for k, v in (project.metadata.get("git") or {}).items():
Expand Down Expand Up @@ -1269,18 +1275,20 @@ def _display_aggregate_results(aggregate_results: pd.DataFrame) -> None:
"langchain.schema.runnable.base.RunnableLambda.html)"
)

## Public API


async def arun_on_dataset(
client: Optional[Client],
dataset_name: str,
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
dataset_version: Optional[Union[datetime, str]] = None,
concurrency_level: int = 5,
project_name: Optional[str] = None,
project_metadata: Optional[Dict[str, Any]] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
revision_id: Optional[str] = None,
**kwargs: Any,
) -> Dict[str, Any]:
Expand All @@ -1289,6 +1297,13 @@ async def arun_on_dataset(
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
if revision_id is None:
revision_id = get_langchain_env_var_metadata().get("revision_id")
tags = kwargs.pop("tags", None)
if tags:
warn_deprecated(
"0.1.9",
message="The tags argument is deprecated and will be"
" removed in a future release. Please specify project_metadata instead.",
)

if kwargs:
warn_deprecated(
Expand All @@ -1310,6 +1325,7 @@ async def arun_on_dataset(
concurrency_level,
project_metadata=project_metadata,
revision_id=revision_id,
dataset_version=dataset_version,
)
batch_results = await runnable_utils.gather_with_concurrency(
container.configs[0].get("max_concurrency"),
Expand All @@ -1332,17 +1348,24 @@ def run_on_dataset(
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
evaluation: Optional[smith_eval.RunEvalConfig] = None,
dataset_version: Optional[Union[datetime, str]] = None,
concurrency_level: int = 5,
project_name: Optional[str] = None,
project_metadata: Optional[Dict[str, Any]] = None,
verbose: bool = False,
tags: Optional[List[str]] = None,
revision_id: Optional[str] = None,
**kwargs: Any,
) -> Dict[str, Any]:
input_mapper = kwargs.pop("input_mapper", None)
if input_mapper:
warn_deprecated("0.0.305", message=_INPUT_MAPPER_DEP_WARNING, pending=True)
tags = kwargs.pop("tags", None)
if tags:
warn_deprecated(
"0.1.9",
message="The tags argument is deprecated and will be"
" removed in a future release. Please specify project_metadata instead.",
)
if revision_id is None:
revision_id = get_langchain_env_var_metadata().get("revision_id")

Expand All @@ -1366,6 +1389,7 @@ def run_on_dataset(
concurrency_level,
project_metadata=project_metadata,
revision_id=revision_id,
dataset_version=dataset_version,
)
if concurrency_level == 0:
batch_results = [
Expand Down Expand Up @@ -1458,8 +1482,8 @@ def construct_chain():
client = Client()
run_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
dataset_name="<my_dataset_name>",
llm_or_chain_factory=construct_chain,
evaluation=evaluation_config,
)
Expand Down Expand Up @@ -1496,8 +1520,8 @@ def _evaluate_strings(self, prediction, reference=None, input=None, **kwargs) ->
run_on_dataset(
client,
"<my_dataset_name>",
construct_chain,
dataset_name="<my_dataset_name>",
llm_or_chain_factory=construct_chain,
evaluation=evaluation_config,
)
""" # noqa: E501
Expand Down

0 comments on commit 1eec67e

Please sign in to comment.