Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Add support for cohere SDK v5 (keeps v4 backwards compatibility) #19084

Merged
merged 6 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 21 additions & 29 deletions docs/docs/integrations/chat/cohere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,10 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"········\n"
]
}
],
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
Expand Down Expand Up @@ -90,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
"metadata": {
"tags": []
Expand All @@ -103,7 +95,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
"metadata": {
"tags": []
Expand All @@ -115,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
"metadata": {
"tags": []
Expand All @@ -124,22 +116,22 @@
{
"data": {
"text/plain": [
"AIMessage(content=\"Who's there?\")"
"AIMessage(content=\"4! That's one, two, three, four. Keep adding and we'll reach new heights!\", response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 21, 'total_tokens': 94, 'billed_tokens': 25}})"
]
},
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [HumanMessage(content=\"knock knock\")]\n",
"messages = [HumanMessage(content=\"1\"), HumanMessage(content=\"2 3\")]\n",
"chat.invoke(messages)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
"metadata": {
"tags": []
Expand All @@ -148,10 +140,10 @@
{
"data": {
"text/plain": [
"AIMessage(content=\"Who's there?\")"
"AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})"
]
},
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -162,7 +154,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
"metadata": {
"tags": []
Expand All @@ -172,7 +164,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Who's there?"
"4! It's a pleasure to be of service in this mathematical game."
]
}
],
Expand All @@ -183,17 +175,17 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"id": "064288e4-f184-4496-9427-bcf148fa055e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content=\"Who's there?\")]"
"[AIMessage(content='4! According to the rules of addition, 1 + 2 equals 3, and 3 + 3 equals 6.', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 73, 'response_tokens': 28, 'total_tokens': 101, 'billed_tokens': 32}})]"
]
},
"execution_count": 6,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -214,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"id": "0851b103",
"metadata": {},
"outputs": [],
Expand All @@ -227,17 +219,17 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "ae950c0f-1691-47f1-b609-273033cae707",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"Why did the bear go to the chiropractor?\\n\\nBecause she was feeling a bit grizzly!\\n\\nHope you found that joke about bears to be a little bit amusing! If you'd like to hear another one, just let me know. In the meantime, if you have any other questions or need assistance with a different topic, feel free to let me know. \\n\\nJust remember, even if you have a sore back like the bear, it's always best to consult a licensed professional for injuries or pain you may be experiencing. \\n\\nWould you like me to tell you another joke?\")"
"AIMessage(content='What do you call a bear with no teeth? A gummy bear!', response_metadata={'documents': None, 'citations': None, 'search_results': None, 'search_queries': None, 'token_count': {'prompt_tokens': 72, 'response_tokens': 14, 'total_tokens': 86, 'billed_tokens': 20}})"
]
},
"execution_count": 8,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -263,7 +255,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
15 changes: 14 additions & 1 deletion docs/docs/integrations/retrievers/cohere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@
"This notebook covers how to get started with Cohere RAG retriever. This allows you to leverage the ability to search documents over various connectors or by supplying your own."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c367be3",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -218,7 +231,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
17 changes: 15 additions & 2 deletions docs/docs/integrations/text_embedding/cohere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@
"Let's load the Cohere Embedding class."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1bfad19b",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass()"
]
},
{
"cell_type": "code",
"execution_count": 2,
Expand Down Expand Up @@ -50,7 +63,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.072631836, 0.06921387, -0.02658081, 0.022705078, 0.027328491, 0.046905518, -0.01838684, -0.029525757, 0.0041046143, -0.028198242, 0.0496521, 0.026901245, 0.03274536, 0.01574707, -0.081726074, -0.022369385, 0.049591064, 0.06549072, -0.015083313, -0.053863525, 0.098083496, 0.034698486, -0.08557129, -0.0024662018, -0.07519531, 0.03265381, 0.006046295, -0.0060691833, 0.032196045, 0.07537842, 9.024143e-05, -0.00869751, 0.022735596, 0.06329346, 0.068481445, -0.006778717, -0.07885742, 0.049560547, -0.008811951, 0.025253296, 0.050750732, -0.05343628, 0.051361084, -0.02319336, 0.026382446, 0.088378906, 0.03567505, -0.0736084, 0.039215088, -0.020584106, -0.03112793, -0.071777344, 0.018218994, -0.01876831, 0.040863037, 0.080078125, 0.046020508, -0.030792236, -0.011779785, -0.024871826, -0.06652832, 0.04748535, -0.038116455, 0.08453369, 0.08746338, 0.059509277, -0.037628174, -0.045410156, -0.054626465, -0.0036334991, -0.035949707, -0.011070251, 0.054534912, 0.0803833, 0.052734375, 0.06689453, 0.0074310303, 0.018249512, -0.023773193, 0.03845215, -0.113220215, 0.014251709, 0.028289795, -0.03942871, 0.029525757, 0.03036499, 0.035095215, 0.031829834, -0.0015306473, 0.027252197, 0.005088806, -0.035858154, -0.113220215, 0.021606445, 0.012046814, -0.06137085, 0.0057640076, -0.06994629, 0.02532959, 0.016952515, -0.010398865, -0.0066184998, -0.020904541, -0.12030029, 0.0036029816, -0.061553955, 0.023956299, -0.07330322, 0.013053894, -0.009613037, -0.062683105, 0.00013184547, 0.12030029, 0.028167725, 0.048614502, -0.09301758, -0.020324707, 0.022369385, -0.14025879, -0.052764893, 0.07220459, 0.028198242, 0.01499939, -0.029449463, 0.004711151, -0.05947876, 0.1640625, -0.09240723, 0.019500732, -0.0031089783, 0.0032081604, -0.0049934387, -0.01676941, 0.002691269, 0.02848816, 0.013504028, -0.057800293, 0.049041748, -0.022384644, 0.05517578, -0.031982422, 0.055389404, 0.0859375, 0.019866943, -0.052978516, 0.030929565, -0.15979004, 0.068481445, -0.020080566, -0.033477783, 0.07922363, -0.020736694, -0.025680542, 0.054016113, -0.028839111, -0.016189575, 0.03564453, 0.0001078248, 0.06304932, -0.022781372, 0.06555176, 0.010093689, 0.03286743, 0.14111328, -0.008468628, -0.04849243, 0.04525757, 0.065979004, -0.012138367, -0.017044067, 0.059509277, 0.035339355, -0.017807007, -0.027267456, -0.0034656525, -0.02078247, -0.033477783, 0.05041504, -0.043518066, -0.064208984, 0.034942627, -0.009300232, -0.08148193, 0.007774353, -0.03540039, -0.008255005, -0.1060791, -0.0703125, 0.091308594, 0.10095215, -0.081970215, 0.02355957, -0.026382446, -0.0070610046, -0.051208496, -0.014961243, 0.07269287, -0.033721924, 0.017669678, -0.08972168, 0.035339355, 0.03579712, -0.07299805, -0.014144897, -0.008850098, 0.023742676, -0.05847168, -0.07873535, -0.015388489, -0.039642334, -0.028930664, 0.008926392, -0.040283203, -0.02897644, -0.013557434, -0.006088257, 0.024169922, -0.10217285, 0.014526367, 0.007381439, -0.0005607605, -0.058410645, -0.008399963, -0.08001709, 0.05065918, 0.01727295, 0.012191772, -0.016571045, 0.03717041, -0.02607727, 0.060760498, 0.057678223, -0.06585693, 0.059173584, 0.023117065, -0.034118652, -0.03189087, 0.010429382, 0.010368347, -0.011230469, -0.020980835, -0.04019165, 0.048187256, -0.019638062, -0.024414062, -0.0019989014, 0.04336548, 0.117248535, 0.00033903122, -0.0014419556, 0.013946533, -0.11541748, 0.030059814, -0.06500244, 0.05441284, 0.021759033, 0.030380249, 0.080566406, 0.02331543, -0.04586792, 0.037322998, 0.011390686, -0.01374054, 0.1459961, -0.050964355, 0.081970215, -0.061645508, 0.07067871, -0.036956787, 0.060455322, 0.051361084, -0.05831909, 0.05328369, -0.008628845, 0.054534912, -0.047332764, 0.030578613, -0.048828125, -0.018112183, 0.022979736, -0.07318115, -0.0423584, -0.094177246, -0.04071045, 0.054260254, 0.0423584, 0.075805664, -0.06365967, 0.009269714, -0.054779053, -0.007637024, -0.01876831, 0.08453369, 0.058898926, -0.07727051, 0.04360962, 0.010574341, -0.027694702, 0.024917603, -0.0463562, 0.040222168, -0.05496216, -0.048461914, 0.013710022, -0.1038208, 0.027954102, 0.031951904, -0.05618286, 0.0025730133, -0.06549072, -0.049957275, 0.01499939, -0.11090088, -0.009017944, 0.021835327, 0.03503418, 0.058746338, -0.12756348, -0.0345459, -0.04699707, -0.029830933, -0.06726074, 0.010612488, -0.024108887, 0.016464233, 0.013076782, -0.06298828, -0.0657959, -0.0025234222, -0.0625, 0.013420105, 0.05810547, -0.006362915, -0.028625488, 0.06085205, 0.12310791, 0.04751587, -0.027740479, -0.02029419, -0.02293396, 0.048858643, -0.006793976, -0.0061073303, 0.029067993, -0.0076942444, -0.00088596344, -0.007446289, 0.12756348, 0.082092285, -0.0037841797, 0.03866577, 0.040374756, 0.019104004, -0.0345459, 0.019042969, -0.038116455, 0.045410156, 0.062683105, -0.024963379, 0.085632324, 0.005897522, 0.008285522, 0.008811951, 0.026504517, 0.025558472, -0.005554199, -0.017822266, -0.112854004, -0.03768921, -0.00097227097, -0.061401367, 0.050567627, -0.010734558, 0.07220459, 0.03643799, 0.0007662773, -0.020980835, -0.04711914, -0.03488159, -0.09655762, 0.0048561096, 0.028030396, 0.04586792, -0.014915466]\n"
"[-0.09338379, 0.0871582, -0.03326416, 0.01953125, 0.07702637, 0.034729004, -0.058380127, -0.031021118, -0.030517578, -0.055999756, 0.050842285, -0.006752014, 0.038391113, -0.0014362335, -0.041137695, -0.008880615, 0.026000977, -0.023010254, 0.05456543, -0.03366089, 0.055633545, 0.028579712, -0.068603516, 0.03970337, -0.06677246, 0.06732178, -0.013053894, -0.0060920715, 0.038116455, 0.057800293, 0.048736572, 0.026855469, 0.009849548, 0.08312988, 0.073791504, 0.01663208, -0.0871582, 0.01802063, -0.0020828247, -0.0031356812, 0.039978027, -0.03164673, 0.009796143, 0.011375427, 0.0068855286, 0.092285156, 0.05218506, -0.060943604, 0.038269043, -0.018218994, -0.04510498, -0.0847168, 0.008300781, -0.060058594, 0.0012111664, 0.05102539, 0.05218506, -0.047210693, -0.051239014, -0.044158936, -0.058166504, 0.07849121, -0.019165039, 0.06451416, 0.024887085, 0.011405945, -0.03768921, -0.018814087, -0.06829834, -0.052825928, -0.019104004, -0.021194458, 0.043518066, 0.07525635, 0.082336426, 0.0037651062, -0.0060310364, -0.03265381, 0.011375427, -0.013847351, -0.07232666, 0.02986145, 0.03866577, -0.029083252, 0.008666992, 0.03845215, 0.045196533, 0.012756348, -0.018051147, 0.032440186, -0.030715942, -0.045440674, -0.11187744, 0.032073975, 0.021972656, -0.044921875, -0.030410767, -0.03668213, 0.12420654, 0.05029297, -0.032989502, -0.049438477, 0.001704216, -0.08074951, 0.00046396255, -0.04107666, 0.020599365, -0.089416504, 0.020477295, -0.038726807, -0.04437256, -0.019256592, 0.048583984, 0.046020508, 0.03741455, -0.037475586, -0.050720215, 0.052856445, -0.10229492, -0.00010281801, 0.058776855, 0.021453857, -0.031051636, 0.01676941, 0.024047852, -0.026306152, 0.15258789, -0.09979248, 0.04888916, 0.045166016, 0.008865356, -0.043914795, -0.032928467, 0.0052757263, 0.06072998, 0.036956787, -0.058013916, 0.053466797, -0.03225708, 0.018371582, -0.0042533875, 0.047943115, 0.06530762, 0.039855957, -0.025360107, 0.047332764, -0.15124512, 0.08325195, 0.016174316, -0.029724121, 0.111816406, -0.05230713, -0.06964111, 0.03060913, -0.04257202, -0.0284729, 0.007843018, -0.03866577, 0.07867432, -0.04446411, 0.028869629, -0.015823364, 0.02659607, 0.085754395, 0.03878784, -0.04232788, 0.017074585, 0.026779175, -0.04284668, -0.017105103, 0.10058594, 0.022323608, -0.007007599, -0.09661865, -0.01322937, -0.004627228, 0.057800293, 0.057159424, -0.033294678, -0.066101074, 0.010910034, 0.033569336, -0.062042236, -0.0072021484, -0.070373535, 0.034729004, -0.07434082, -0.06604004, 0.061401367, 0.09576416, -0.070739746, 0.066833496, -0.019042969, -0.0051994324, -0.07696533, -0.03564453, 0.048614502, -0.048919678, 0.036224365, -0.06652832, 0.03338623, 0.05847168, 0.009414673, -0.035095215, 0.011787415, -0.007675171, -0.057006836, -0.045074463, -0.027999878, -0.049102783, -0.025787354, -0.010101318, -0.000813961, -0.009963989, -0.013343811, 0.04046631, 0.02758789, -0.07086182, 0.09442139, -0.012275696, -0.018936157, -0.011940002, 0.10638428, -0.10913086, 0.05606079, 0.008895874, 0.017089844, 0.019958496, 0.03173828, -0.037322998, 0.019699097, 0.046722412, -0.08959961, 0.059448242, 0.018875122, -0.057495117, -0.039276123, 0.009063721, -0.0178833, 0.032073975, -0.08178711, -0.061431885, 0.05731201, 0.012886047, -0.025360107, 0.04498291, 0.027923584, 0.125, 0.013374329, -0.013069153, -0.031677246, -0.109558105, 0.05731201, -0.03765869, 0.04650879, -0.005706787, 0.021697998, -0.0008239746, 0.030090332, -0.048736572, 0.07940674, -0.017120361, 0.018737793, 0.12011719, -0.03564453, 0.07519531, -0.039611816, -0.014968872, -0.045288086, 0.07702637, 0.010681152, -0.04736328, 0.07623291, 0.008071899, 0.080078125, -0.060516357, 0.043426514, -0.026489258, -0.018188477, 0.049560547, -0.068847656, -0.03387451, -0.09661865, -0.03768921, 0.028549194, 0.036621094, 0.05307007, -0.053894043, 0.0019035339, -0.07788086, -0.010597229, -0.027420044, 0.10900879, 0.019302368, -0.06726074, 0.04937744, 0.05154419, -0.050598145, 0.07562256, -0.05569458, 0.073913574, -0.052337646, -0.0149383545, -0.00037050247, 0.037322998, 0.018478394, -0.03201294, -0.04788208, 0.03062439, -0.055786133, 0.0018081665, 0.029510498, -0.10864258, -0.027374268, 0.040405273, 0.01474762, -0.010726929, -0.086242676, -0.02658081, -0.057159424, -0.0095825195, -0.11804199, -0.014289856, -0.006881714, -0.028533936, 0.005382538, -0.053771973, -0.015853882, 0.0034332275, -0.08441162, -0.028182983, -0.00856781, -0.060394287, -0.036590576, 0.03062439, 0.112854004, -0.008041382, -0.03353882, 0.0181427, -0.03466797, 0.026565552, -0.033813477, 0.0074310303, -0.02017212, -0.047729492, 0.00010108948, -0.032073975, 0.08630371, 0.08557129, -0.0115737915, 0.044067383, 0.062042236, 0.00819397, -0.016082764, 0.01574707, 0.0154418945, 0.06726074, 0.056884766, 0.01210022, 0.048095703, -0.0017309189, 0.018295288, -0.00592041, 0.062286377, 0.040649414, -0.032928467, -0.05392456, -0.13891602, -0.033050537, 0.047973633, -0.07824707, 0.024627686, -0.02923584, 0.09118652, 0.0690918, 0.045837402, -0.06402588, -0.028747559, -0.06542969, -0.08496094, 0.06762695, 0.04220581, 0.059539795, 0.0023174286]\n"
]
}
],
Expand Down Expand Up @@ -103,7 +116,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.7"
},
"vscode": {
"interpreter": {
Expand Down
18 changes: 14 additions & 4 deletions libs/community/langchain_community/chat_models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_cohere_chat_request(
"AUTO" if documents is not None or connectors is not None else None
)

return {
req = {
"message": messages[-1].content,
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
Expand All @@ -91,6 +91,8 @@ def get_cohere_chat_request(
**kwargs,
}

return {k: v for k, v in req.items() if v is not None}


class ChatCohere(BaseChatModel, BaseCohere):
"""`Cohere` chat large language models.
Expand Down Expand Up @@ -142,7 +144,11 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = self.client.chat(**request, stream=True)

if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
else:
stream = self.client.chat(**request, stream=True)

for data in stream:
if data.event_type == "text-generation":
Expand All @@ -160,7 +166,11 @@ async def _astream(
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = await self.async_client.chat(**request, stream=True)

if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
else:
stream = self.async_client.chat(**request, stream=True)

async for data in stream:
if data.event_type == "text-generation":
Expand Down Expand Up @@ -220,7 +230,7 @@ async def _agenerate(
return await agenerate_from_stream(stream_iter)

request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request, stream=False)
response = self.client.chat(**request)

message = AIMessage(content=response.text)
generation_info = None
Expand Down