Skip to content

Commit

Permalink
community[patch] : adds feedback and status for Fiddler callback hand…
Browse files Browse the repository at this point in the history
…ler events (#19157)

**Description:** This PR updates the Fiddler events schema to also
pass user feedback and LLM status to Fiddler
   **Tickets:** [INTERNAL] FDL-17559 
   **Dependencies:**  NA
   **Twitter handle:** behalder

Co-authored-by: Barun Halder <barun@fiddler.ai>
  • Loading branch information
2 people authored and hinthornw committed Apr 26, 2024
1 parent 3d9fe2a commit e535829
Showing 1 changed file with 91 additions and 33 deletions.
124 changes: 91 additions & 33 deletions libs/community/langchain_community/callbacks/fiddler_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
Expand All @@ -15,6 +16,11 @@
COMPLETION_TOKENS = "completion_tokens"
RUN_ID = "run_id"
MODEL_NAME = "model_name"
GOOD = "good"
BAD = "bad"
NEUTRAL = "neutral"
SUCCESS = "success"
FAILURE = "failure"

# Default values
DEFAULT_MAX_TOKEN = 65536
Expand All @@ -23,12 +29,20 @@
# Fiddler specific constants
PROMPT = "prompt"
RESPONSE = "response"
CONTEXT = "context"
DURATION = "duration"
FEEDBACK = "feedback"
LLM_STATUS = "llm_status"

FEEDBACK_POSSIBLE_VALUES = [GOOD, BAD, NEUTRAL]

# Define a dataset dictionary
_dataset_dict = {
PROMPT: ["fiddler"] * 10,
RESPONSE: ["fiddler"] * 10,
CONTEXT: ["fiddler"] * 10,
FEEDBACK: ["good"] * 10,
LLM_STATUS: ["success"] * 10,
MODEL_NAME: ["fiddler"] * 10,
RUN_ID: ["123e4567-e89b-12d3-a456-426614174000"] * 10,
TOTAL_TOKENS: [0, DEFAULT_MAX_TOKEN] * 5,
Expand Down Expand Up @@ -83,8 +97,9 @@ def __init__(
self.api_key = api_key
self._df = self.pd.DataFrame(_dataset_dict)

self.run_id_prompts: Dict[str, List[str]] = {}
self.run_id_starttime: Dict[str, int] = {}
self.run_id_prompts: Dict[UUID, List[str]] = {}
self.run_id_response: Dict[UUID, List[str]] = {}
self.run_id_starttime: Dict[UUID, int] = {}

# Initialize Fiddler client here
self.fiddler_client = self.fdl.FiddlerApi(url, org_id=org, auth_token=api_key)
Expand All @@ -105,6 +120,17 @@ def __init__(
dataset_info = self.fdl.DatasetInfo.from_dataframe(
self._df, max_inferred_cardinality=0
)

# Set feedback column to categorical
for i in range(len(dataset_info.columns)):
if dataset_info.columns[i].name == FEEDBACK:
dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY
dataset_info.columns[i].possible_values = FEEDBACK_POSSIBLE_VALUES

elif dataset_info.columns[i].name == LLM_STATUS:
dataset_info.columns[i].data_type = self.fdl.DataType.CATEGORY
dataset_info.columns[i].possible_values = [SUCCESS, FAILURE]

if self.model not in self.fiddler_client.get_dataset_names(self.project):
print( # noqa: T201
f"adding dataset {self.model} to project {self.project}."
Expand All @@ -128,13 +154,15 @@ def __init__(
dataset_info=dataset_info,
dataset_id="train",
model_task=self.fdl.ModelTask.LLM,
features=[PROMPT, RESPONSE],
features=[PROMPT, RESPONSE, CONTEXT],
target=FEEDBACK,
metadata_cols=[
RUN_ID,
TOTAL_TOKENS,
PROMPT_TOKENS,
COMPLETION_TOKENS,
MODEL_NAME,
DURATION,
],
custom_features=self.custom_features,
)
Expand Down Expand Up @@ -228,6 +256,42 @@ def custom_features(self) -> list:
),
]

def _publish_events(
    self,
    run_id: UUID,
    prompt_responses: List[str],
    duration: int,
    llm_status: str,
    model_name: Optional[str] = "",
    token_usage_dict: Optional[Dict[str, Any]] = None,
) -> None:
    """Publish one batch of LLM events for *run_id* to Fiddler.

    Builds a dataframe with one row per prompt recorded for this run
    (prompt, response, run id, duration, status, model name, plus any
    token-usage columns) and sends it via ``publish_events_batch``.
    Publishing is best-effort: any error is printed and swallowed so the
    callback never breaks the surrounding LLM call.
    """
    num_rows = len(self.run_id_prompts[run_id])
    columns: Dict[str, Any] = {
        PROMPT: self.run_id_prompts[run_id],
        RESPONSE: prompt_responses,
        RUN_ID: [str(run_id)] * num_rows,
        DURATION: [duration] * num_rows,
        LLM_STATUS: [llm_status] * num_rows,
        MODEL_NAME: [model_name] * num_rows,
    }
    events_df = self.pd.DataFrame(columns)

    # Token counts are per-run scalars; replicate ints across all rows,
    # otherwise hand the value to pandas as-is.
    for key, value in (token_usage_dict or {}).items():
        events_df[key] = [value] * num_rows if isinstance(value, int) else value

    try:
        self.fiddler_client.publish_events_batch(self.project, self.model, events_df)
    except Exception as e:
        print(  # noqa: T201
            f"Error publishing events to fiddler: {e}. continuing..."
        )

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:
Expand All @@ -237,42 +301,36 @@ def on_llm_start(

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Record a successful LLM run and publish its events to Fiddler.

    Extracts per-prompt response texts, token usage, and model name from
    *response*, computes the elapsed run time, and forwards everything to
    ``_publish_events`` with a SUCCESS status.
    """
    flattened_llmresult = response.flatten()
    run_id = kwargs[RUN_ID]
    # Elapsed seconds since on_llm_start stored the start time.
    # Fixed: was `start - now`, which produced negative durations;
    # on_llm_error already computes `now - start`.
    run_duration = int(time.time()) - self.run_id_starttime[run_id]
    model_name = ""
    token_usage_dict = {}

    if isinstance(response.llm_output, dict):
        # Keep only the token-count fields Fiddler's schema knows about.
        token_usage_dict = {
            k: v
            for k, v in response.llm_output.items()
            if k in [TOTAL_TOKENS, PROMPT_TOKENS, COMPLETION_TOKENS]
        }
        model_name = response.llm_output.get(MODEL_NAME, "")

    # One response text per flattened generation (first candidate each).
    prompt_responses = [
        llmresult.generations[0][0].text for llmresult in flattened_llmresult
    ]

    self._publish_events(
        run_id,
        prompt_responses,
        run_duration,
        SUCCESS,
        model_name,
        token_usage_dict,
    )

df[MODEL_NAME] = model_name
df[RUN_ID] = str(run_id)
df[DURATION] = run_duration
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
    """Record a failed LLM run and publish a FAILURE event to Fiddler.

    Responses are blank (one empty string per recorded prompt) since the
    run produced no output; duration is elapsed seconds since
    on_llm_start stored the start time.
    """
    run_id = kwargs[RUN_ID]
    duration = int(time.time()) - self.run_id_starttime[run_id]

    self._publish_events(
        run_id, [""] * len(self.run_id_prompts[run_id]), duration, FAILURE
    )

0 comments on commit e535829

Please sign in to comment.