Skip to content

Commit

Permalink
experimental: docstrings update (#18048)
Browse files Browse the repository at this point in the history
Added missing docstrings. Formatted docstrings to a consistent format.
  • Loading branch information
leo-gan committed Feb 24, 2024
1 parent 56b955f commit 3f6bf85
Show file tree
Hide file tree
Showing 61 changed files with 316 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class AutoGPT:
"""Agent class for interacting with Auto-GPT."""
"""Agent for interacting with AutoGPT."""

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class PromptGenerator:
"""A class for generating custom prompt strings.
"""Generator of custom prompt strings.
Does this based on constraints, commands, resources, and performance evaluations.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@


class HuggingGPT:
"""Agent for interacting with HuggingGPT."""

def __init__(self, llm: BaseLanguageModel, tools: List[BaseTool]):
self.llm = llm
self.tools = tools
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:


class ResponseGenerator:
"""Generates a response based on the input."""

def __init__(self, llm_chain: LLMChain, stop: Optional[List] = None):
self.llm_chain = llm_chain
self.stop = stop
Expand All @@ -36,6 +38,8 @@ def generate(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) ->


def load_response_generator(llm: BaseLanguageModel) -> ResponseGenerator:
"""Load the ResponseGenerator."""

llm_chain = ResponseGenerationChain.from_llm(llm)
return ResponseGenerator(
llm_chain=llm_chain,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@


class Task:
"""Task to be executed."""

def __init__(self, task: str, id: int, dep: List[int], args: Dict, tool: BaseTool):
self.task = task
self.id = id
Expand Down Expand Up @@ -74,7 +76,7 @@ def run(self) -> str:


class TaskExecutor:
"""Load tools to execute tasks."""
"""Load tools and execute tasks."""

def __init__(self, plan: Plan):
self.plan = plan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def from_llm(


class Step:
"""A step in the plan."""

def __init__(
self, task: str, id: int, dep: List[int], args: Dict[str, str], tool: BaseTool
):
Expand All @@ -87,6 +89,8 @@ def __init__(


class Plan:
"""A plan to execute."""

def __init__(self, steps: List[Step]):
self.steps = steps

Expand All @@ -98,6 +102,8 @@ def __repr__(self) -> str:


class BasePlanner(BaseModel):
"""Base class for a planner."""

@abstractmethod
def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
"""Given input, decide what to do."""
Expand All @@ -106,11 +112,22 @@ def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan
async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""Given input, decide what to do."""
"""Asynchronous Given input, decide what to do."""


class PlanningOutputParser(BaseModel):
"""Parses the output of the planning stage."""

def parse(self, text: str, hf_tools: List[BaseTool]) -> Plan:
"""Parse the output of the planning stage.
Args:
text: The output of the planning stage.
hf_tools: The tools available.
Returns:
The plan.
"""
steps = []
for v in json.loads(re.findall(r"\[.*\]", text)[0]):
choose_tool = None
Expand All @@ -124,6 +141,8 @@ def parse(self, text: str, hf_tools: List[BaseTool]) -> Plan:


class TaskPlanner(BasePlanner):
"""Planner for tasks."""

llm_chain: LLMChain
output_parser: PlanningOutputParser
stop: Optional[List] = None
Expand All @@ -139,7 +158,7 @@ def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan
async def aplan(
self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
) -> Plan:
"""Given input, decided what to do."""
"""Asynchronous Given input, decided what to do."""
inputs["hf_tools"] = [
f"{tool.name}: {tool.description}" for tool in inputs["hf_tools"]
]
Expand All @@ -150,5 +169,7 @@ async def aplan(


def load_chat_planner(llm: BaseLanguageModel) -> TaskPlanner:
"""Load the chat planner."""

llm_chain = TaskPlaningChain.from_llm(llm)
return TaskPlanner(llm_chain=llm_chain, output_parser=PlanningOutputParser())
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@


class ChatWrapper(BaseChatModel):
"""Wrapper for chat LLMs."""

llm: LLM
sys_beg: str
sys_end: str
Expand Down Expand Up @@ -130,6 +132,8 @@ def _to_chat_result(llm_result: LLMResult) -> ChatResult:


class Llama2Chat(ChatWrapper):
"""Wrapper for Llama-2-chat model."""

@property
def _llm_type(self) -> str:
return "llama-2-chat"
Expand All @@ -145,6 +149,8 @@ def _llm_type(self) -> str:


class Orca(ChatWrapper):
"""Wrapper for Orca-style models."""

@property
def _llm_type(self) -> str:
return "orca-style"
Expand All @@ -158,6 +164,8 @@ def _llm_type(self) -> str:


class Vicuna(ChatWrapper):
"""Wrapper for Vicuna-style models."""

@property
def _llm_type(self) -> str:
return "vicuna-style"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@


class AmazonComprehendModerationChain(Chain):
"""A subclass of Chain, designed to apply moderation to LLMs."""
"""Moderation Chain, based on `Amazon Comprehend` service.
See more at https://aws.amazon.com/comprehend/
"""

output_key: str = "output" #: :meta private:
"""Key used to fetch/store the output in data containers. Defaults to `output`"""
Expand Down Expand Up @@ -54,7 +57,7 @@ class AmazonComprehendModerationChain(Chain):
@root_validator(pre=True)
def create_client(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Creates an Amazon Comprehend client
Creates an Amazon Comprehend client.
Args:
values (Dict[str, Any]): A dictionary containing configuration values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@


class BaseModeration:
"""Base class for moderation."""

def __init__(
self,
client: Any,
Expand Down Expand Up @@ -109,6 +111,8 @@ def _log_message_for_verbose(self, message: str) -> None:
self.run_manager.on_text(message)

def moderate(self, prompt: Any) -> str:
"""Moderate the input prompt."""

from langchain_experimental.comprehend_moderation.base_moderation_config import ( # noqa: E501
ModerationPiiConfig,
ModerationPromptSafetyConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


class BaseModerationCallbackHandler:
"""Base class for moderation callback handlers."""

def __init__(self) -> None:
if (
self._is_method_unchanged(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@


class ModerationPiiConfig(BaseModel):
"""Configuration for PII moderation filter."""

threshold: float = 0.5
"""Threshold for PII confidence score, defaults to 0.5 i.e. 50%"""

Expand All @@ -21,6 +23,8 @@ class ModerationPiiConfig(BaseModel):


class ModerationToxicityConfig(BaseModel):
"""Configuration for Toxicity moderation filter."""

threshold: float = 0.5
"""Threshold for Toxic label confidence score, defaults to 0.5 i.e. 50%"""

Expand All @@ -29,6 +33,8 @@ class ModerationToxicityConfig(BaseModel):


class ModerationPromptSafetyConfig(BaseModel):
"""Configuration for Prompt Safety moderation filter."""

threshold: float = 0.5
"""
Threshold for Prompt Safety classification
Expand All @@ -37,6 +43,8 @@ class ModerationPromptSafetyConfig(BaseModel):


class BaseModerationConfig(BaseModel):
"""Base configuration settings for moderation."""

filters: List[
Union[
ModerationPiiConfig, ModerationToxicityConfig, ModerationPromptSafetyConfig
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(


class ModerationPromptSafetyError(Exception):
"""Exception raised if Intention entities are detected.
"""Exception raised if Unsafe prompts are detected.
Attributes:
message -- explanation of the error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@


class ComprehendPII:
"""Class to handle Personally Identifiable Information (PII) moderation."""

def __init__(
self,
client: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@


class ComprehendPromptSafety:
"""Class to handle prompt safety moderation."""

def __init__(
self,
client: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@


class ComprehendToxicity:
"""Class to handle toxicity moderation."""

def __init__(
self,
client: Any,
Expand Down
2 changes: 1 addition & 1 deletion libs/experimental/langchain_experimental/cpal/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _call(


class NarrativeChain(_BaseStoryElementChain):
"""Decompose the narrative into its story elements
"""Decompose the narrative into its story elements.
- causal model
- query
Expand Down
23 changes: 17 additions & 6 deletions libs/experimental/langchain_experimental/cpal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

class NarrativeModel(BaseModel):
"""
Represent the narrative input as three story elements.
Narrative input as three story elements.
"""

story_outcome_question: str
Expand All @@ -33,6 +33,8 @@ def empty_str_to_none(cls, v: str) -> Union[str, None]:


class EntityModel(BaseModel):
"""Entity in the story."""

name: str = Field(description="entity name")
code: str = Field(description="entity actions")
value: float = Field(description="entity initial value")
Expand All @@ -51,14 +53,17 @@ def lower_case_name(cls, v: str) -> str:


class CausalModel(BaseModel):
"""Casual data."""

attribute: str = Field(description="name of the attribute to be calculated")
entities: List[EntityModel] = Field(description="entities in the story")

# TODO: root validate each `entity.depends_on` using system's entity names


class EntitySettingModel(BaseModel):
"""
"""Entity initial conditions.
Initial conditions for an entity
{"name": "bud", "attribute": "pet_count", "value": 12}
Expand All @@ -75,7 +80,8 @@ def lower_case_transform(cls, v: str) -> str:


class SystemSettingModel(BaseModel):
"""
"""System initial conditions.
Initial global conditions for the system.
{"parameter": "interest_rate", "value": .05}
Expand All @@ -86,8 +92,7 @@ class SystemSettingModel(BaseModel):


class InterventionModel(BaseModel):
"""
aka initial conditions
"""Intervention data of the story aka initial conditions.
>>> intervention.dict()
{
Expand All @@ -110,7 +115,9 @@ def lower_case_name(cls, v: str) -> Union[str, None]:


class QueryModel(BaseModel):
"""translate a question about the story outcome into a programmatic expression"""
"""Query data of the story.
Translate a question about the story outcome into a programmatic expression."""

question: str = Field(alias=Constant.narrative_input.value) # input
expression: str # output, part of llm completion
Expand All @@ -119,11 +126,15 @@ class QueryModel(BaseModel):


class ResultModel(BaseModel):
"""Result of the story query."""

question: str = Field(alias=Constant.narrative_input.value) # input
_result_table: str = PrivateAttr() # result of the executed query


class StoryModel(BaseModel):
"""Story data."""

causal_operations: Any = Field(required=True)
intervention: Any = Field(required=True)
query: Any = Field(required=True)
Expand Down

0 comments on commit 3f6bf85

Please sign in to comment.