Skip to content

Commit

Permalink
Add local agent (#23438)
Browse files Browse the repository at this point in the history
* Add local agent

* Document LocalAgent
  • Loading branch information
sgugger committed May 18, 2023
1 parent db13634 commit cf43200
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 4 deletions.
6 changes: 5 additions & 1 deletion docs/source/en/main_classes/agent.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,16 @@ contains the API docs for the underlying classes.

## Agents

We provide two types of agents: [`HfAgent`] uses inference endpoints for opensource models and [`OpenAiAgent`] uses OpenAI closed models.
We provide three types of agents: [`HfAgent`] uses inference endpoints for opensource models, [`LocalAgent`] uses a model of your choice locally and [`OpenAiAgent`] uses OpenAI closed models.

### HfAgent

[[autodoc]] HfAgent

### LocalAgent

[[autodoc]] LocalAgent

### OpenAiAgent

[[autodoc]] OpenAiAgent
Expand Down
13 changes: 12 additions & 1 deletion src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@
"tools": [
"Agent",
"HfAgent",
"LocalAgent",
"OpenAiAgent",
"PipelineTool",
"RemoteTool",
Expand Down Expand Up @@ -4361,7 +4362,17 @@
)

# Tools
from .tools import Agent, HfAgent, OpenAiAgent, PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool
from .tools import (
Agent,
HfAgent,
LocalAgent,
OpenAiAgent,
PipelineTool,
RemoteTool,
Tool,
launch_gradio_demo,
load_tool,
)

# Trainer
from .trainer_callback import (
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


_import_structure = {
"agents": ["Agent", "HfAgent", "OpenAiAgent"],
"agents": ["Agent", "HfAgent", "LocalAgent", "OpenAiAgent"],
"base": ["PipelineTool", "RemoteTool", "Tool", "launch_gradio_demo", "load_tool"],
}

Expand All @@ -46,7 +46,7 @@
_import_structure["translation"] = ["TranslationTool"]

if TYPE_CHECKING:
from .agents import Agent, HfAgent, OpenAiAgent
from .agents import Agent, HfAgent, LocalAgent, OpenAiAgent
from .base import PipelineTool, RemoteTool, Tool, launch_gradio_demo, load_tool

try:
Expand Down
113 changes: 113 additions & 0 deletions src/transformers/tools/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces

from ..generation import StoppingCriteria, StoppingCriteriaList
from ..models.auto import AutoModelForCausalLM, AutoTokenizer
from ..utils import is_openai_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, CHAT_PROMPT_TEMPLATE, RUN_PROMPT_TEMPLATE
Expand Down Expand Up @@ -492,3 +494,114 @@ def generate_one(self, prompt, stop):
if result.endswith(stop_seq):
return result[: -len(stop_seq)]
return result


class LocalAgent(Agent):
"""
Agent that uses a local model and tokenizer to generate code.
Args:
model ([`PreTrainedModel`]):
The model to use for the agent.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer to use for the agent.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
checkpoint = "bigcode/starcoder"
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
agent = LocalAgent(model, tokenizer)
agent.run("Draw me a picture of rivers and lakes.")
```
"""

def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
self.model = model
self.tokenizer = tokenizer
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Convenience method to build a `LocalAgent` from a pretrained checkpoint.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
kwargs:
Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
Example:
```py
import torch
from transformers import LocalAgent
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
agent.run("Draw me a picture of rivers and lakes.")
```
"""
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(model, tokenizer)

@property
def _model_device(self):
if hasattr(self.model, "hf_device_map"):
return list(self.model.hf_device_map.values())[0]
for param in self.mode.parameters():
return param.device

def generate_one(self, prompt, stop):
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
src_len = encoded_inputs["input_ids"].shape[1]
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
outputs = self.model.generate(
encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
)

result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq):
result = result[: -len(stop_seq)]
return result


class StopSequenceCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever a sequence of tokens is encountered.
Args:
stop_sequences (`str` or `List[str]`):
The sequence (or list of sequences) on which to stop execution.
tokenizer:
The tokenizer used to decode the model outputs.
"""

def __init__(self, stop_sequences, tokenizer):
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
self.stop_sequences = stop_sequences
self.tokenizer = tokenizer

def __call__(self, input_ids, scores, **kwargs) -> bool:
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)

0 comments on commit cf43200

Please sign in to comment.