community[patch]: adds feedback and status for Fiddler callback handler events #19157

Merged · 1 commit · Mar 15, 2024
124 changes: 91 additions & 33 deletions libs/community/langchain_community/callbacks/fiddler_callback.py
@@ -1,5 +1,6 @@
 import time
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
+from uuid import UUID

 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.outputs import LLMResult
@@ -15,6 +16,11 @@
 COMPLETION_TOKENS = "completion_tokens"
 RUN_ID = "run_id"
 MODEL_NAME = "model_name"
+GOOD = "good"
+BAD = "bad"
+NEUTRAL = "neutral"
+SUCCESS = "success"
+FAILURE = "failure"

 # Default values
 DEFAULT_MAX_TOKEN = 65536
@@ -23,12 +29,20 @@
 # Fiddler specific constants
 PROMPT = "prompt"
 RESPONSE = "response"
+CONTEXT = "context"
+DURATION = "duration"
+FEEDBACK = "feedback"
+LLM_STATUS = "llm_status"
+
+FEEDBACK_POSSIBLE_VALUES = [GOOD, BAD, NEUTRAL]

 # Define a dataset dictionary
 _dataset_dict = {
     PROMPT: ["fiddler"] * 10,
     RESPONSE: ["fiddler"] * 10,
+    CONTEXT: ["fiddler"] * 10,
+    FEEDBACK: ["good"] * 10,
+    LLM_STATUS: ["success"] * 10,
     MODEL_NAME: ["fiddler"] * 10,
     RUN_ID: ["123e4567-e89b-12d3-a456-426614174000"] * 10,
     TOTAL_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
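
Note for reviewers: the baseline dataset seeded at import time now includes the three new columns, so Fiddler infers a schema that covers them. Purely as illustration (values mirror `_dataset_dict` above; this snippet is not part of the diff), one seeded row looks like:

    row = {
        "prompt": "fiddler",
        "response": "fiddler",
        "context": "fiddler",
        "feedback": "good",       # later constrained to FEEDBACK_POSSIBLE_VALUES
        "llm_status": "success",  # later constrained to ["success", "failure"]
        "model_name": "fiddler",
        "run_id": "123e4567-e89b-12d3-a456-426614174000",
    }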
@@ -83,8 +97,9 @@ def __init__(
         self.api_key = api_key
         self._df = self.pd.DataFrame(_dataset_dict)

-        self.run_id_prompts: Dict[str, List[str]] = {}
-        self.run_id_starttime: Dict[str, int] = {}
+        self.run_id_prompts: Dict[UUID, List[str]] = {}
+        self.run_id_response: Dict[UUID, List[str]] = {}
+        self.run_id_starttime: Dict[UUID, int] = {}

         # Initialize Fiddler client here
         self.fiddler_client = self.fdl.FiddlerApi(url, org_id=org, auth_token=api_key)
@@ -105,6 +120,17 @@ def __init__(
         dataset_info = self.fdl.DatasetInfo.from_dataframe(
             self._df, max_inferred_cardinality=0
         )
+
+        # Set feedback column to categorical
+        for i in range(len(dataset_info.columns)):
+            if dataset_info.columns[i].name == FEEDBACK:
+                dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY
+                dataset_info.columns[i].possible_values = FEEDBACK_POSSIBLE_VALUES
+
+            elif dataset_info.columns[i].name == LLM_STATUS:
+                dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY
+                dataset_info.columns[i].possible_values = [SUCCESS, FAILURE]
+
         if self.model not in self.fiddler_client.get_dataset_names(self.project):
             print(  # noqa: T201
                 f"adding dataset {self.model} to project {self.project}."
@@ -128,13 +154,15 @@ def __init__(
                 dataset_info=dataset_info,
                 dataset_id="train",
                 model_task=self.fdl.ModelTask.LLM,
-                features=[PROMPT, RESPONSE],
+                features=[PROMPT, RESPONSE, CONTEXT],
+                target=FEEDBACK,
                 metadata_cols=[
                     RUN_ID,
                     TOTAL_TOKENS,
                     PROMPT_TOKENS,
                     COMPLETION_TOKENS,
+                    MODEL_NAME,
+                    DURATION,
                 ],
                 custom_features=self.custom_features,
             )
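
With `target=FEEDBACK`, the registered LLM model now scores against user feedback, with `context` as a third input feature. An illustrative summary of the resulting column roles (this dict is not the output of any fiddler-client call):

    schema_summary = {
        "features": [PROMPT, RESPONSE, CONTEXT],
        "target": FEEDBACK,  # categorical: good / bad / neutral
        "metadata": [RUN_ID, TOTAL_TOKENS, PROMPT_TOKENS,
                     COMPLETION_TOKENS, MODEL_NAME, DURATION],
    }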
@@ -228,6 +256,42 @@ def custom_features(self) -> list:
             ),
         ]

+    def _publish_events(
+        self,
+        run_id: UUID,
+        prompt_responses: List[str],
+        duration: int,
+        llm_status: str,
+        model_name: Optional[str] = "",
+        token_usage_dict: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Publish events to fiddler
+        """
+
+        prompt_count = len(self.run_id_prompts[run_id])
+        df = self.pd.DataFrame(
+            {
+                PROMPT: self.run_id_prompts[run_id],
+                RESPONSE: prompt_responses,
+                RUN_ID: [str(run_id)] * prompt_count,
+                DURATION: [duration] * prompt_count,
+                LLM_STATUS: [llm_status] * prompt_count,
+                MODEL_NAME: [model_name] * prompt_count,
+            }
+        )
+
+        if token_usage_dict:
+            for key, value in token_usage_dict.items():
+                df[key] = [value] * prompt_count if isinstance(value, int) else value
+
+        try:
+            self.fiddler_client.publish_events_batch(self.project, self.model, df)
+        except Exception as e:
+            print(  # noqa: T201
+                f"Error publishing events to fiddler: {e}. continuing..."
+            )
+
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
     ) -> Any:
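
To make the helper's broadcasting behaviour concrete: for a run with two prompts, `_publish_events` builds one row per prompt and repeats every scalar across rows. A hedged sketch with hypothetical inputs (the pandas usage mirrors the handler's `self.pd`):

    import pandas as pd

    prompts = ["What is MLOps?", "What is LLMOps?"]  # hypothetical run inputs
    responses = ["MLOps is ...", "LLMOps is ..."]

    df = pd.DataFrame(
        {
            "prompt": prompts,
            "response": responses,
            "run_id": ["123e4567-e89b-12d3-a456-426614174000"] * 2,
            "duration": [3] * 2,  # seconds, broadcast to both rows
            "llm_status": ["success"] * 2,
            "model_name": ["example-model"] * 2,  # placeholder name
        }
    )
    # Integer entries of token_usage_dict (e.g. {"total_tokens": 42}) are
    # broadcast the same way, so every row is fully populated before
    # publish_events_batch is called.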
@@ -237,42 +301,36 @@ def on_llm_start(

     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         flattened_llmresult = response.flatten()
-        token_usage_dict = {}
         run_id = kwargs[RUN_ID]
         run_duration = self.run_id_starttime[run_id] - int(time.time())
-        prompt_responses = []
         model_name = ""
+        token_usage_dict = {}

         if isinstance(response.llm_output, dict):
-            if TOKEN_USAGE in response.llm_output:
-                token_usage_dict = response.llm_output[TOKEN_USAGE]
-            if MODEL_NAME in response.llm_output:
-                model_name = response.llm_output[MODEL_NAME]
-
-        for llmresult in flattened_llmresult:
-            prompt_responses.append(llmresult.generations[0][0].text)
-
-        df = self.pd.DataFrame(
-            {
-                PROMPT: self.run_id_prompts[run_id],
-                RESPONSE: prompt_responses,
+            token_usage_dict = {
+                k: v
+                for k, v in response.llm_output.items()
+                if k in [TOTAL_TOKENS, PROMPT_TOKENS, COMPLETION_TOKENS]
             }
-        )
-
-        if TOTAL_TOKENS in token_usage_dict:
-            df[PROMPT_TOKENS] = int(token_usage_dict[TOTAL_TOKENS])
+            model_name = response.llm_output.get(MODEL_NAME, "")

-        if PROMPT_TOKENS in token_usage_dict:
-            df[TOTAL_TOKENS] = int(token_usage_dict[PROMPT_TOKENS])
+        prompt_responses = [
+            llmresult.generations[0][0].text for llmresult in flattened_llmresult
+        ]

-        if COMPLETION_TOKENS in token_usage_dict:
-            df[COMPLETION_TOKENS] = token_usage_dict[COMPLETION_TOKENS]
+        self._publish_events(
+            run_id,
+            prompt_responses,
+            run_duration,
+            SUCCESS,
+            model_name,
+            token_usage_dict,
+        )

-        df[MODEL_NAME] = model_name
-        df[RUN_ID] = str(run_id)
-        df[DURATION] = run_duration
+    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
+        run_id = kwargs[RUN_ID]
+        duration = int(time.time()) - self.run_id_starttime[run_id]

-        try:
-            self.fiddler_client.publish_events_batch(self.project, self.model, df)
-        except Exception as e:
-            print(f"Error publishing events to fiddler: {e}. continuing...")  # noqa: T201
+        self._publish_events(
+            run_id, [""] * len(self.run_id_prompts[run_id]), duration, FAILURE
+        )
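
For anyone exercising the change end to end, a minimal usage sketch. It assumes the class exported by this module is `FiddlerCallbackHandler` and that the constructor takes the fields referenced in `__init__` above; the URL, org, and key are placeholders:

    from langchain_community.callbacks.fiddler_callback import FiddlerCallbackHandler
    from langchain_openai import OpenAI

    fiddler_handler = FiddlerCallbackHandler(
        url="https://demo.fiddler.ai",  # placeholder deployment URL
        org="my_org",                   # placeholder org id
        project="langchain_project",    # project the handler registers against
        model="langchain_model",        # model/dataset name inside the project
        api_key="...",                  # Fiddler auth token
    )

    llm = OpenAI(temperature=0, callbacks=[fiddler_handler])

    # A successful generation publishes rows with llm_status="success";
    # an exception during the call publishes llm_status="failure" instead.
    llm.invoke("What is the capital of France?")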