Skip to content

Commit

Permalink
docs: Update list of chat models tool calling providers (#20330)
Browse files Browse the repository at this point in the history
Will follow up with a few missing providers
  • Loading branch information
eyurtsev committed Apr 11, 2024
1 parent 653489a commit 0e74fb4
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions docs/scripts/model_feat_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
"PromptLayerOpenAI": {"batch_generate": False, "batch_agenerate": False},
}
CHAT_MODEL_IGNORE = ("FakeListChatModel", "HumanInputChatModel")

CHAT_MODEL_FEAT_TABLE_CORRECTION = {
"ChatMLflowAIGateway": {"_agenerate": False},
"PromptLayerChatOpenAI": {"_stream": False, "_astream": False},
"ChatKonko": {"_astream": False, "_agenerate": False},
"ChatOpenAI": {"tool_calling": True},
"ChatAnthropic": {"tool_calling": True},
"ChatMistralAI": {"tool_calling": True},
}


LLM_TEMPLATE = """\
---
sidebar_position: 1
Expand Down Expand Up @@ -101,6 +106,7 @@ def get_llm_table():
"_astream",
"batch_generate",
"batch_agenerate",
"tool_calling",
]
title = [
"Model",
Expand All @@ -110,14 +116,16 @@ def get_llm_table():
"Async stream",
"Batch",
"Async batch",
"Tool calling",
]
rows = [title, [":-"] + [":-:"] * (len(title) - 1)]
for llm, feats in sorted(final_feats.items()):
rows += [[llm, "βœ…"] + ["βœ…" if feats.get(h) else "❌" for h in header[1:]]]
return "\n".join(["|".join(row) for row in rows])


def get_chat_model_table():
def get_chat_model_table() -> str:
"""Get the table of chat models."""
feat_table = {}
for cm in chat_models.__all__:
feat_table[cm] = {}
Expand All @@ -133,8 +141,15 @@ def get_chat_model_table():
for k, v in {**feat_table, **CHAT_MODEL_FEAT_TABLE_CORRECTION}.items()
if k not in CHAT_MODEL_IGNORE
}
header = ["model", "_agenerate", "_stream", "_astream"]
title = ["Model", "Invoke", "Async invoke", "Stream", "Async stream"]
header = ["model", "_agenerate", "_stream", "_astream", "tool_calling"]
title = [
"Model",
"Invoke",
"Async invoke",
"Stream",
"Async stream",
"Tool calling",
]
rows = [title, [":-"] + [":-:"] * (len(title) - 1)]
for llm, feats in sorted(final_feats.items()):
rows += [[llm, "βœ…"] + ["βœ…" if feats.get(h) else "❌" for h in header[1:]]]
Expand Down

0 comments on commit 0e74fb4

Please sign in to comment.