Add local agent #23438

Merged: 2 commits, May 18, 2023
6 changes: 5 additions & 1 deletion docs/source/en/main_classes/agent.mdx
@@ -24,12 +24,16 @@ contains the API docs for the underlying classes.

## Agents

We provide two types of agents: [`HfAgent`] uses inference endpoints for opensource models and [`OpenAiAgent`] uses OpenAI closed models.
We provide three types of agents: [`HfAgent`] uses inference endpoints for open-source models, [`LocalAgent`] runs a model of your choice locally, and [`OpenAiAgent`] uses OpenAI's closed models.

### HfAgent

[[autodoc]] HfAgent

### LocalAgent

[[autodoc]] LocalAgent

### OpenAiAgent

[[autodoc]] OpenAiAgent
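
For illustration, a minimal sketch of how the three agent types are constructed (not part of the diff; the endpoint URL and checkpoint names are illustrative assumptions, and `OpenAiAgent` additionally needs an API key):

```py
from transformers import HfAgent, LocalAgent, OpenAiAgent

# Open-source model served behind a hosted inference endpoint
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")

# A model of your choice, downloaded and run locally (added by this PR)
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto")

# OpenAI's closed models, accessed through their API
agent = OpenAiAgent(model="text-davinci-003", api_key="<your-openai-api-key>")
```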
13 changes: 12 additions & 1 deletion src/transformers/__init__.py
@@ -614,6 +614,7 @@
"tools": [
"Agent",
"HfAgent",
"LocalAgent",
"OpenAiAgent",
"PipelineTool",
"RemoteTool",
@@ -4361,7 +4362,17 @@
    )

    # Tools
    from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
    from .tools import (
        Agent,
        HfAgent,
        LocalAgent,
        OpenAiAgent,
        PipelineTool,
        RemoteTool,
        Tool,
        launch_gradio_demo,
        load_tool,
    )

    # Trainer
    from .trainer_callback import (
4 changes: 2 additions & 2 deletions src/transformers/tools/__init__.py
@@ -24,7 +24,7 @@


_import_structure = {
    "agents": ["Agent", "HfAgent", "OpenAiAgent"],
    "agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
    "base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}

@@ -46,7 +46,7 @@
    _import_structure["translation"] = ["TranslationTool"]

if TYPE_CHECKING:
    from .agents import Agent, HfAgent, OpenAiAgent
    from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
    from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

    try:
113 changes: 113 additions & 0 deletions src/transformers/tools/agents.py
@@ -24,6 +24,8 @@
import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces

from ..generation import StoppingCriteria, StoppingCriteriaList
from ..models.auto import AutoModelForCausalLM, AutoTokenizer
from ..utils import is_openai_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
@@ -492,3 +494,114 @@ def generate_one(self, prompt, stop):
            if result.endswith(stop_seq):
                return result[: -len(stop_seq)]
        return result


class LocalAgent(Agent):
    """
    Agent that uses a local model and tokenizer to generate code.

    Args:
        model ([`PreTrainedModel`]):
            The model to use for the agent.
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer to use for the agent.
        chat_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `chat` method.
        run_prompt_template (`str`, *optional*):
            Pass along your own prompt if you want to override the default template for the `run` method.
        additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
            Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
            one of the default tools, that default tool will be overridden.

    Example:

    ```py
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent

    checkpoint = "bigcode/starcoder"
    model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    agent = LocalAgent(model, tokenizer)
    agent.run("Draw me a picture of rivers and lakes.")
    ```
    """

    def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
        self.model = model
        self.tokenizer = tokenizer
        super().__init__(
            chat_prompt_template=chat_prompt_template,
            run_prompt_template=run_prompt_template,
            additional_tools=additional_tools,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Convenience method to build a `LocalAgent` from a pretrained checkpoint.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
            kwargs:
                Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].

        Example:

        ```py
        import torch
        from transformers import LocalAgent

        agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
        agent.run("Draw me a picture of rivers and lakes.")
        ```
        """
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
Review comment on lines +561 to +562 (Member): Nice shortcut!
        return cls(model, tokenizer)

    @property
    def _model_device(self):
        if hasattr(self.model, "hf_device_map"):
            # With a device map (e.g. `device_map="auto"`), the first entry is the device of the input embeddings.
            return list(self.model.hf_device_map.values())[0]
        for param in self.model.parameters():
            return param.device

    def generate_one(self, prompt, stop):
        encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
        src_len = encoded_inputs["input_ids"].shape[1]
        stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
        outputs = self.model.generate(
            encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
        )

        result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
        # Generation only stops after the stop sequence has been produced, so remove it from the result.
        for stop_seq in stop:
            if result.endswith(stop_seq):
                result = result[: -len(stop_seq)]
        return result


Review thread on `StopSequenceCriteria`:

Member: Can we:

1. move it to the stopping criteria file? (I've seen a user asking for this)
2. generalize it for the batched input case? (or add a todo, I'd be happy to expand it :D)

Collaborator (Author): Maybe you can move it as you add support for the batched case? I don't need it for the agents and it's not an obvious thing to do (that's the reason I didn't put this in the stopping criteria file, by the way).

Member: Sure, I can take care of it afterwards 👍

Collaborator (Author): Thanks!

class StopSequenceCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever a sequence of tokens is encountered.
Args:
stop_sequences (`str` or `List[str]`):
The sequence (or list of sequences) on which to stop execution.
tokenizer:
The tokenizer used to decode the model outputs.
"""

def __init__(self, stop_sequences, tokenizer):
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
self.stop_sequences = stop_sequences
self.tokenizer = tokenizer

def __call__(self, input_ids, scores, **kwargs) -> bool:
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
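
For illustration, a minimal sketch of using `StopSequenceCriteria` on its own with `generate` (not part of the diff; the checkpoint, prompt, and stop string are illustrative assumptions):

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from transformers.tools.agents import StopSequenceCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Task: draw rivers and lakes.\nAnswer:", return_tensors="pt")
# Halt generation as soon as the decoded output ends with the stop string "Task:"
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria("Task:", tokenizer)])
outputs = model.generate(inputs["input_ids"], max_new_tokens=50, stopping_criteria=stopping_criteria)
print(tokenizer.decode(outputs[0]))
```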
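
As the review thread above suggests, the criteria could later be generalized to batched inputs. One possible sketch, purely an assumption about that follow-up and not part of this PR: stop only once every sequence in the batch ends with one of the stop strings.

```py
from transformers import StoppingCriteria


class BatchedStopSequenceCriteria(StoppingCriteria):
    """Hypothetical batched variant: stops only when *every* sequence in the batch ends with a stop sequence."""

    def __init__(self, stop_sequences, tokenizer):
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Decode every sequence in the batch and require all of them to have hit a stop string.
        decoded = self.tokenizer.batch_decode(input_ids)
        return all(any(text.endswith(stop) for stop in self.stop_sequences) for text in decoded)
```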