Skip to content

Commit 3be7c1c

Browse files
authored Sep 4, 2024
fix (provider/anthropic): support prompt caching on assistant messages (#2890)
1 parent 8b19cf0 commit 3be7c1c

11 files changed

+248
-17
lines changed
 

‎.changeset/two-boxes-attend.md

+7
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
---
2+
'@ai-sdk/anthropic': patch
3+
'@ai-sdk/provider': patch
4+
'ai': patch
5+
---
6+
7+
fix (provider/anthropic): support prompt caching on assistant messages

‎packages/ai/core/generate-text/generate-text.test.ts

+7-4
Original file line number | Diff line number | Diff line change
@@ -390,7 +390,7 @@ describe('result.responseMessages', () => {
390390
doGenerate: async ({ prompt, mode }) => {
391391
switch (responseCount++) {
392392
case 0:
393-
assert.deepStrictEqual(mode, {
393+
expect(mode).toStrictEqual({
394394
type: 'regular',
395395
toolChoice: { type: 'auto' },
396396
tools: [
@@ -408,7 +408,8 @@ describe('result.responseMessages', () => {
408408
},
409409
],
410410
});
411-
assert.deepStrictEqual(prompt, [
411+
412+
expect(prompt).toStrictEqual([
412413
{
413414
role: 'user',
414415
content: [{ type: 'text', text: 'test-input' }],
@@ -441,7 +442,7 @@ describe('result.responseMessages', () => {
441442
},
442443
};
443444
case 1:
444-
assert.deepStrictEqual(mode, {
445+
expect(mode).toStrictEqual({
445446
type: 'regular',
446447
toolChoice: { type: 'auto' },
447448
tools: [
@@ -459,7 +460,8 @@ describe('result.responseMessages', () => {
459460
},
460461
],
461462
});
462-
assert.deepStrictEqual(prompt, [
463+
464+
expect(prompt).toStrictEqual([
463465
{
464466
role: 'user',
465467
content: [
@@ -477,6 +479,7 @@ describe('result.responseMessages', () => {
477479
toolCallId: 'call-1',
478480
toolName: 'tool1',
479481
args: { value: 'value' },
482+
providerMetadata: undefined,
480483
},
481484
],
482485
providerMetadata: undefined,

‎packages/ai/core/generate-text/stream-text.test.ts

+7-4
Original file line number | Diff line number | Diff line change
@@ -1997,7 +1997,7 @@ describe('options.maxToolRoundtrips', () => {
19971997
doStream: async ({ prompt, mode }) => {
19981998
switch (responseCount++) {
19991999
case 0:
2000-
assert.deepStrictEqual(mode, {
2000+
expect(mode).toStrictEqual({
20012001
type: 'regular',
20022002
tools: [
20032003
{
@@ -2015,7 +2015,8 @@ describe('options.maxToolRoundtrips', () => {
20152015
],
20162016
toolChoice: { type: 'auto' },
20172017
});
2018-
assert.deepStrictEqual(prompt, [
2018+
2019+
expect(prompt).toStrictEqual([
20192020
{
20202021
role: 'user',
20212022
content: [{ type: 'text', text: 'test-input' }],
@@ -2041,7 +2042,7 @@ describe('options.maxToolRoundtrips', () => {
20412042
rawCall: { rawPrompt: 'prompt', rawSettings: {} },
20422043
};
20432044
case 1:
2044-
assert.deepStrictEqual(mode, {
2045+
expect(mode).toStrictEqual({
20452046
type: 'regular',
20462047
tools: [
20472048
{
@@ -2059,7 +2060,8 @@ describe('options.maxToolRoundtrips', () => {
20592060
],
20602061
toolChoice: { type: 'auto' },
20612062
});
2062-
assert.deepStrictEqual(prompt, [
2063+
2064+
expect(prompt).toStrictEqual([
20632065
{
20642066
role: 'user',
20652067
content: [{ type: 'text', text: 'test-input' }],
@@ -2072,6 +2074,7 @@ describe('options.maxToolRoundtrips', () => {
20722074
toolCallId: 'call-1',
20732075
toolName: 'tool1',
20742076
args: { value: 'value' },
2077+
providerMetadata: undefined,
20752078
},
20762079
],
20772080
providerMetadata: undefined,

‎packages/ai/core/prompt/content-part.ts

+7
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,13 @@ Name of the tool that is being called.
8484
Arguments of the tool call. This is a JSON-serializable object that matches the tool's input schema.
8585
*/
8686
args: unknown;
87+
88+
/**
89+
Additional provider-specific metadata. They are passed through
90+
to the provider from the AI SDK and enable provider-specific
91+
functionality that can be fully encapsulated in the provider.
92+
*/
93+
experimental_providerMetadata?: ProviderMetadata;
8794
}
8895

8996
export const toolCallPartSchema: z.ZodType<ToolCallPart> = z.object({

‎packages/ai/core/prompt/convert-to-language-model-prompt.test.ts

+43
Original file line number | Diff line number | Diff line change
@@ -230,5 +230,48 @@ describe('convertToLanguageModelMessage', () => {
230230
});
231231
});
232232
});
233+
234+
describe('tool call parts', () => {
235+
it('should pass through provider metadata', () => {
236+
const result = convertToLanguageModelMessage(
237+
{
238+
role: 'assistant',
239+
content: [
240+
{
241+
type: 'tool-call',
242+
toolName: 'toolName',
243+
toolCallId: 'toolCallId',
244+
args: {},
245+
experimental_providerMetadata: {
246+
'test-provider': {
247+
'key-a': 'test-value-1',
248+
'key-b': 'test-value-2',
249+
},
250+
},
251+
},
252+
],
253+
},
254+
null,
255+
);
256+
257+
expect(result).toEqual({
258+
role: 'assistant',
259+
content: [
260+
{
261+
type: 'tool-call',
262+
args: {},
263+
toolCallId: 'toolCallId',
264+
toolName: 'toolName',
265+
providerMetadata: {
266+
'test-provider': {
267+
'key-a': 'test-value-1',
268+
'key-b': 'test-value-2',
269+
},
270+
},
271+
},
272+
],
273+
});
274+
});
275+
});
233276
});
234277
});

‎packages/ai/core/prompt/convert-to-language-model-prompt.ts

+12-4
Original file line number | Diff line number | Diff line change
@@ -219,10 +219,18 @@ export function convertToLanguageModelMessage(
219219

220220
return {
221221
role: 'assistant',
222-
content: message.content.filter(
223-
// remove empty text parts:
224-
part => part.type !== 'text' || part.text !== '',
225-
),
222+
content: message.content
223+
.filter(
224+
// remove empty text parts:
225+
part => part.type !== 'text' || part.text !== '',
226+
)
227+
.map(part => {
228+
const { experimental_providerMetadata, ...rest } = part;
229+
return {
230+
...rest,
231+
providerMetadata: experimental_providerMetadata,
232+
};
233+
}),
226234
providerMetadata: message.experimental_providerMetadata,
227235
};
228236
}

‎packages/anthropic/src/anthropic-messages-prompt.ts

+1
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ export interface AnthropicToolCallContent {
4040
id: string;
4141
name: string;
4242
input: unknown;
43+
cache_control?: AnthropicCacheControl;
4344
}
4445

4546
export interface AnthropicToolResultContent {

‎packages/anthropic/src/convert-to-anthropic-messages-prompt.test.ts

+129
Original file line number | Diff line number | Diff line change
@@ -405,6 +405,135 @@ describe('cache control', () => {
405405
});
406406
});
407407

408+
describe('assistant message', () => {
409+
it('should set cache_control on assistant message text part with part cache control', async () => {
410+
const result = convertToAnthropicMessagesPrompt({
411+
prompt: [
412+
{ role: 'user', content: [{ type: 'text', text: 'user-content' }] },
413+
{
414+
role: 'assistant',
415+
content: [
416+
{
417+
type: 'text',
418+
text: 'test',
419+
providerMetadata: {
420+
anthropic: {
421+
cacheControl: { type: 'ephemeral' },
422+
},
423+
},
424+
},
425+
],
426+
},
427+
],
428+
cacheControl: true,
429+
});
430+
431+
expect(result).toEqual({
432+
messages: [
433+
{ role: 'user', content: [{ type: 'text', text: 'user-content' }] },
434+
{
435+
role: 'assistant',
436+
content: [
437+
{
438+
type: 'text',
439+
text: 'test',
440+
cache_control: { type: 'ephemeral' },
441+
},
442+
],
443+
},
444+
],
445+
system: undefined,
446+
});
447+
});
448+
449+
it('should set cache_control on assistant tool call part with part cache control', async () => {
450+
const result = convertToAnthropicMessagesPrompt({
451+
prompt: [
452+
{ role: 'user', content: [{ type: 'text', text: 'user-content' }] },
453+
{
454+
role: 'assistant',
455+
content: [
456+
{
457+
type: 'tool-call',
458+
toolCallId: 'test-id',
459+
toolName: 'test-tool',
460+
args: { some: 'arg' },
461+
providerMetadata: {
462+
anthropic: {
463+
cacheControl: { type: 'ephemeral' },
464+
},
465+
},
466+
},
467+
],
468+
},
469+
],
470+
cacheControl: true,
471+
});
472+
473+
expect(result).toEqual({
474+
messages: [
475+
{ role: 'user', content: [{ type: 'text', text: 'user-content' }] },
476+
{
477+
role: 'assistant',
478+
content: [
479+
{
480+
type: 'tool_use',
481+
name: 'test-tool',
482+
id: 'test-id',
483+
input: { some: 'arg' },
484+
cache_control: { type: 'ephemeral' },
485+
},
486+
],
487+
},
488+
],
489+
system: undefined,
490+
});
491+
});
492+
493+
it('should set cache_control on last assistant message part with message cache control', async () => {
494+
const result = convertToAnthropicMessagesPrompt({
495+
prompt: [
496+
{ role: 'user', content: [{ type: 'text', text: 'user-content' }] },
497+
{
498+
role: 'assistant',
499+
content: [
500+
{ type: 'text', text: 'part1' },
501+
{ type: 'text', text: 'part2' },
502+
],
503+
providerMetadata: {
504+
anthropic: {
505+
cacheControl: { type: 'ephemeral' },
506+
},
507+
},
508+
},
509+
],
510+
cacheControl: true,
511+
});
512+
513+
expect(result).toEqual({
514+
messages: [
515+
{ role: 'user', content: [{ type: 'text', text: 'user-content' }] },
516+
{
517+
role: 'assistant',
518+
content: [
519+
{
520+
type: 'text',
521+
text: 'part1',
522+
cache_control: undefined,
523+
},
524+
{
525+
type: 'text',
526+
text: 'part2',
527+
cache_control: { type: 'ephemeral' },
528+
},
529+
],
530+
},
531+
],
532+
system: undefined,
533+
});
534+
});
535+
});
536+
408537
describe('tool message', () => {
409538
it('should set cache_control on tool result message part with part cache control', async () => {
410539
const result = convertToAnthropicMessagesPrompt({

‎packages/anthropic/src/convert-to-anthropic-messages-prompt.ts

+20-5
Original file line number | Diff line number | Diff line change
@@ -72,13 +72,13 @@ export function convertToAnthropicMessagesPrompt({
7272
const { role, content } = message;
7373
switch (role) {
7474
case 'user': {
75-
for (let i = 0; i < content.length; i++) {
76-
const part = content[i];
75+
for (let j = 0; j < content.length; j++) {
76+
const part = content[j];
7777

7878
// cache control: first add cache control from part.
7979
// for the last part of a message,
8080
// check also if the message has cache control.
81-
const isLastPart = i === content.length - 1;
81+
const isLastPart = j === content.length - 1;
8282

8383
const cacheControl =
8484
getCacheControl(part.providerMetadata) ??
@@ -162,9 +162,23 @@ export function convertToAnthropicMessagesPrompt({
162162
// combines multiple assistant messages in this block into a single message:
163163
const anthropicContent: AnthropicAssistantMessage['content'] = [];
164164

165-
for (const { content } of block.messages) {
165+
for (const message of block.messages) {
166+
const { content } = message;
167+
166168
for (let j = 0; j < content.length; j++) {
167169
const part = content[j];
170+
171+
// cache control: first add cache control from part.
172+
// for the last part of a message,
173+
// check also if the message has cache control.
174+
const isLastPart = j === content.length - 1;
175+
176+
const cacheControl =
177+
getCacheControl(part.providerMetadata) ??
178+
(isLastPart
179+
? getCacheControl(message.providerMetadata)
180+
: undefined);
181+
168182
switch (part.type) {
169183
case 'text': {
170184
anthropicContent.push({
@@ -177,7 +191,7 @@ export function convertToAnthropicMessagesPrompt({
177191
? part.text.trim()
178192
: part.text,
179193

180-
cache_control: undefined, // not used in assistant messages
194+
cache_control: cacheControl,
181195
});
182196
break;
183197
}
@@ -188,6 +202,7 @@ export function convertToAnthropicMessagesPrompt({
188202
id: part.toolCallId,
189203
name: part.toolName,
190204
input: part.args,
205+
cache_control: cacheControl,
191206
});
192207
break;
193208
}

‎packages/provider/src/language-model/v1/language-model-v1-call-options.ts

+8
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ import { JSONSchema7 } from 'json-schema';
22
import { LanguageModelV1CallSettings } from './language-model-v1-call-settings';
33
import { LanguageModelV1FunctionTool } from './language-model-v1-function-tool';
44
import { LanguageModelV1Prompt } from './language-model-v1-prompt';
5+
import { LanguageModelV1ProviderMetadata } from './language-model-v1-provider-metadata';
56
import { LanguageModelV1ToolChoice } from './language-model-v1-tool-choice';
67

78
export type LanguageModelV1CallOptions = LanguageModelV1CallSettings & {
@@ -76,4 +77,11 @@ That approach allows us to evolve the user facing prompts without breaking
7677
the language model interface.
7778
*/
7879
prompt: LanguageModelV1Prompt;
80+
81+
/**
82+
* Additional provider-specific metadata. They are passed through
83+
* to the provider from the AI SDK and enable provider-specific
84+
* functionality that can be fully encapsulated in the provider.
85+
*/
86+
providerMetadata?: LanguageModelV1ProviderMetadata;
7987
};

‎packages/provider/src/language-model/v1/language-model-v1-prompt.ts

+7
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,13 @@ Name of the tool that is being called.
104104
Arguments of the tool call. This is a JSON-serializable object that matches the tool's input schema.
105105
*/
106106
args: unknown;
107+
108+
/**
109+
* Additional provider-specific metadata. They are passed through
110+
* to the provider from the AI SDK and enable provider-specific
111+
* functionality that can be fully encapsulated in the provider.
112+
*/
113+
providerMetadata?: LanguageModelV1ProviderMetadata;
107114
}
108115

109116
/**

0 commit comments

Comments
 (0)
Please sign in to comment.