Commit db61c53

authored Sep 6, 2024

feat (ai/core): middleware support (#2759)

1 parent 7ee8d32 commit db61c53

23 files changed: +886 -7 lines
 

.changeset/many-yaks-relate.md (new file, +5)

```
---
'ai': patch
---

feat (ai/core): middleware support
```
content/docs/03-ai-sdk-core/40-provider-management.mdx (+1 -1)

```diff
@@ -5,7 +5,7 @@ description: Learn how to work with multiple providers
 # Provider Management
 
-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>
 
 When you work with multiple providers and models, it is often desirable to manage them in a central place
 and access the models through simple string ids.
```
(new file, +209)

---
title: Language Model Middleware
description: Learn how to use middleware to enhance the behavior of language models
---

# Language Model Middleware

<Note type="warning">
  Language model middleware is an experimental feature.
</Note>

Language model middleware is a way to enhance the behavior of language models
by intercepting and modifying the calls to the language model.

It can be used to add features like guardrails, RAG, caching, and logging
in a language model agnostic way. Such middleware can be developed and
distributed independently from the language models that it is applied to.

## Using Language Model Middleware

You can use language model middleware with the `wrapLanguageModel` function.
It takes a language model and a language model middleware and returns a new
language model that incorporates the middleware.

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

const wrappedLanguageModel = wrapLanguageModel({
  model: yourModel,
  middleware: yourLanguageModelMiddleware,
});
```

The wrapped language model can be used just like any other language model, e.g. in `streamText`:

```ts highlight="2"
const result = await streamText({
  model: wrappedLanguageModel,
  prompt: 'What cities are in the United States?',
});
```
## Implementing Language Model Middleware

<Note>
  Implementing language model middleware is advanced functionality and requires
  a solid understanding of the [language model
  specification](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
</Note>

You can implement any of the following three functions to modify the behavior of the language model:

1. `transformParams`: Transforms the parameters before they are passed to the language model, for both `doGenerate` and `doStream`.
2. `wrapGenerate`: Wraps the `doGenerate` method of the [language model](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
   You can modify the parameters, call the language model, and modify the result.
3. `wrapStream`: Wraps the `doStream` method of the [language model](https://github.com/vercel/ai/blob/main/packages/provider/src/language-model/v1/language-model-v1.ts).
   You can modify the parameters, call the language model, and modify the result.
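Taken together, a pass-through middleware that implements all three hooks can make the contract concrete. This is a minimal sketch with simplified local stand-ins for the SDK types (the type names below are assumptions for illustration only; real middleware would use `Experimental_LanguageModelV1Middleware` from `ai`):

```ts
// Simplified stand-ins for the SDK types (assumptions for this sketch only).
type CallOptions = { prompt: string };
type GenerateResult = { text: string };

type MiddlewareSketch = {
  transformParams?: (options: {
    type: 'generate' | 'stream';
    params: CallOptions;
  }) => Promise<CallOptions>;
  wrapGenerate?: (options: {
    doGenerate: () => Promise<GenerateResult>;
    params: CallOptions;
  }) => Promise<GenerateResult>;
  wrapStream?: (options: {
    doStream: () => Promise<AsyncIterable<string>>;
    params: CallOptions;
  }) => Promise<AsyncIterable<string>>;
};

// A pass-through middleware: each hook forwards to the wrapped model unchanged.
const passThroughMiddleware: MiddlewareSketch = {
  transformParams: async ({ params }) => params,
  wrapGenerate: async ({ doGenerate }) => doGenerate(),
  wrapStream: async ({ doStream }) => doStream(),
};
```

All three hooks are optional; a real middleware typically implements only the hooks it needs and leaves the rest undefined.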
Here are some examples of how to implement language model middleware:

## Examples

<Note>
  These examples are not meant to be used in production. They are just to show
  how you can use middleware to enhance the behavior of language models.
</Note>

### Logging

This example shows how to log the parameters and generated text of a language model call.

```ts
import type {
  Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware,
  LanguageModelV1StreamPart,
} from 'ai';

export const yourLogMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    console.log('doGenerate called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const result = await doGenerate();

    console.log('doGenerate finished');
    console.log(`generated text: ${result.text}`);

    return result;
  },

  wrapStream: async ({ doStream, params }) => {
    console.log('doStream called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const { stream, ...rest } = await doStream();

    let generatedText = '';

    const transformStream = new TransformStream<
      LanguageModelV1StreamPart,
      LanguageModelV1StreamPart
    >({
      transform(chunk, controller) {
        if (chunk.type === 'text-delta') {
          generatedText += chunk.textDelta;
        }

        controller.enqueue(chunk);
      },

      flush() {
        console.log('doStream finished');
        console.log(`generated text: ${generatedText}`);
      },
    });

    return {
      stream: stream.pipeThrough(transformStream),
      ...rest,
    };
  },
};
```

### Caching

This example shows how to build a simple cache for the generated text of a language model call.

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

const cache = new Map<string, any>();

export const yourCacheMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    const cacheKey = JSON.stringify(params);

    if (cache.has(cacheKey)) {
      return cache.get(cacheKey);
    }

    const result = await doGenerate();

    cache.set(cacheKey, result);

    return result;
  },

  // here you would implement the caching logic for streaming
};
```

### Retrieval Augmented Generation (RAG)

This example shows how to use RAG as middleware.

<Note>
  Helper functions like `getLastUserMessageText` and `findSources` are not part
  of the AI SDK. They are just used in this example to illustrate the concept of
  RAG.
</Note>

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourRagMiddleware: LanguageModelV1Middleware = {
  transformParams: async ({ params }) => {
    const lastUserMessageText = getLastUserMessageText({
      prompt: params.prompt,
    });

    if (lastUserMessageText == null) {
      return params; // do not use RAG (send unmodified parameters)
    }

    const instruction =
      'Use the following information to answer the question:\n' +
      findSources({ text: lastUserMessageText })
        .map(chunk => JSON.stringify(chunk))
        .join('\n');

    return addToLastUserMessage({ params, text: instruction });
  },
};
```

### Guardrails

Guardrails are a way to ensure that the generated text of a language model call
is safe and appropriate. This example shows how to use guardrails as middleware.

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourGuardrailMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate }) => {
    const { text, ...rest } = await doGenerate();

    // filtering approach, e.g. for PII or other sensitive information:
    const cleanedText = text?.replace(/badword/g, '<REDACTED>');

    return { text: cleanedText, ...rest };
  },

  // here you would implement the guardrail logic for streaming
  // Note: streaming guardrails are difficult to implement, because
  // you do not know the full content of the stream until it's finished.
};
```

content/docs/07-reference/ai-sdk-core/40-provider-registry.mdx (+3 -3)

```diff
@@ -1,11 +1,11 @@
 ---
-title: experimental_createProviderRegistry
+title: createProviderRegistry
 description: Registry for managing multiple providers and models (API Reference)
 ---
 
-# `experimental_createProviderRegistry()`
+# `createProviderRegistry()`
 
-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>
 
 When you work with multiple providers and models, it is often desirable to manage them
 in a central place and access the models through simple string ids.
```

content/docs/07-reference/ai-sdk-core/42-custom-provider.mdx (+3 -3)

```diff
@@ -1,11 +1,11 @@
 ---
-title: experimental_customProvider
+title: customProvider
 description: Custom provider that uses models from a different provider (API Reference)
 ---
 
-# `experimental_customProvider()`
+# `customProvider()`
 
-<Note>Provider management is an experimental feature.</Note>
+<Note type="warning">Provider management is an experimental feature.</Note>
 
 With a custom provider, you can map ids to any model.
 This allows you to set up custom model configurations, alias names, and more.
```
(new file, +65)

---
title: wrapLanguageModel
description: Function for wrapping a language model with middleware (API Reference)
---

# `wrapLanguageModel()`

<Note type="warning">
  Language model middleware is an experimental feature.
</Note>

The `experimental_wrapLanguageModel` function provides a way to enhance the behavior of language models
by wrapping them with middleware.
See [Language Model Middleware](/docs/ai-sdk-core/middleware) for more information on middleware.

```ts
import { experimental_wrapLanguageModel as wrapLanguageModel } from 'ai';

const wrappedLanguageModel = wrapLanguageModel({
  model: yourModel,
  middleware: yourLanguageModelMiddleware,
});
```

## Import

<Snippet
  text={`import { experimental_wrapLanguageModel as wrapLanguageModel } from "ai"`}
  prompt={false}
/>

## API Signature

### Parameters

<PropertiesTable
  content={[
    {
      name: 'model',
      type: 'LanguageModelV1',
      description: 'The original LanguageModelV1 instance to be wrapped.',
    },
    {
      name: 'middleware',
      type: 'Experimental_LanguageModelV1Middleware',
      description: 'The middleware to be applied to the language model.',
    },
    {
      name: 'modelId',
      type: 'string',
      description:
        "Optional custom model ID to override the original model's ID.",
    },
    {
      name: 'providerId',
      type: 'string',
      description:
        "Optional custom provider ID to override the original model's provider.",
    },
  ]}
/>

### Returns

A new `LanguageModelV1` instance with middleware applied.
(new file, +46)

---
title: LanguageModelV1Middleware
description: Middleware for enhancing language model behavior (API Reference)
---

# `LanguageModelV1Middleware`

<Note type="warning">
  Language model middleware is an experimental feature.
</Note>

Language model middleware provides a way to enhance the behavior of language models
by intercepting and modifying the calls to the language model. It can be used to add
features like guardrails, RAG, caching, and logging in a language model agnostic way.

See [Language Model Middleware](/docs/ai-sdk-core/middleware) for more information.

## Import

<Snippet
  text={`import { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from "ai"`}
  prompt={false}
/>

## API Signature

<PropertiesTable
  content={[
    {
      name: 'transformParams',
      type: '({ type: "generate" | "stream", params: LanguageModelV1CallOptions }) => Promise<LanguageModelV1CallOptions>',
      description:
        'Transforms the parameters before they are passed to the language model.',
    },
    {
      name: 'wrapGenerate',
      type: '({ doGenerate: DoGenerateFunction, params: LanguageModelV1CallOptions, model: LanguageModelV1 }) => Promise<DoGenerateResult>',
      description: 'Wraps the generate operation of the language model.',
    },
    {
      name: 'wrapStream',
      type: '({ doStream: DoStreamFunction, params: LanguageModelV1CallOptions, model: LanguageModelV1 }) => Promise<DoStreamResult>',
      description: 'Wraps the stream operation of the language model.',
    },
  ]}
/>
(new file, +28)

```ts
import { LanguageModelV1CallOptions } from 'ai';

export function addToLastUserMessage({
  text,
  params,
}: {
  text: string;
  params: LanguageModelV1CallOptions;
}): LanguageModelV1CallOptions {
  const { prompt, ...rest } = params;

  const lastMessage = prompt.at(-1);

  if (lastMessage?.role !== 'user') {
    return params;
  }

  return {
    ...rest,
    prompt: [
      ...prompt.slice(0, -1),
      {
        ...lastMessage,
        content: [{ type: 'text', text }, ...lastMessage.content],
      },
    ],
  };
}
```
(new file, +35)

```ts
import { openai } from '@ai-sdk/openai';
import {
  generateText,
  experimental_wrapLanguageModel as wrapLanguageModel,
} from 'ai';
import 'dotenv/config';
import { yourCacheMiddleware } from './your-cache-middleware';

async function main() {
  const modelWithCaching = wrapLanguageModel({
    model: openai('gpt-4o'),
    middleware: yourCacheMiddleware,
  });

  const start1 = Date.now();
  const result1 = await generateText({
    model: modelWithCaching,
    prompt: 'What cities are in the United States?',
  });
  const end1 = Date.now();

  const start2 = Date.now();
  const result2 = await generateText({
    model: modelWithCaching,
    prompt: 'What cities are in the United States?',
  });
  const end2 = Date.now();

  console.log(`Time taken for result1: ${end1 - start1}ms`);
  console.log(`Time taken for result2: ${end2 - start2}ms`);

  console.log(result1.text === result2.text);
}

main().catch(console.error);
```
(new file, +19)

```ts
import { openai } from '@ai-sdk/openai';
import {
  generateText,
  experimental_wrapLanguageModel as wrapLanguageModel,
} from 'ai';
import 'dotenv/config';
import { yourLogMiddleware } from './your-log-middleware';

async function main() {
  const result = await generateText({
    model: wrapLanguageModel({
      model: openai('gpt-4o'),
      middleware: yourLogMiddleware,
    }),
    prompt: 'What cities are in the United States?',
  });
}

main().catch(console.error);
```
(new file, +17)

```ts
import { LanguageModelV1Prompt } from 'ai';

export function getLastUserMessageText({
  prompt,
}: {
  prompt: LanguageModelV1Prompt;
}): string | undefined {
  const lastMessage = prompt.at(-1);

  if (lastMessage?.role !== 'user') {
    return undefined;
  }

  return lastMessage.content.length === 0
    ? undefined
    : lastMessage.content
        // project text parts to their text field; joining the part objects
        // directly would produce "[object Object]" strings
        .flatMap(part => (part.type === 'text' ? [part.text] : []))
        .join('\n');
}
```
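One detail worth calling out in helpers like this: filtering a content array leaves you with part *objects*, which must be mapped to their `text` field before joining. A self-contained sketch with a simplified content-part union (an assumption, not the SDK's actual type):

```ts
// Simplified content-part union (assumption for this sketch).
type ContentPart =
  | { type: 'text'; text: string }
  | { type: 'image'; url: string };

function joinTextParts(content: ContentPart[]): string | undefined {
  return content.length === 0
    ? undefined
    : content
        // keep only text parts, projecting each to its string field
        .flatMap(part => (part.type === 'text' ? [part.text] : []))
        .join('\n');
}
```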
(new file, +23)

```ts
import { openai } from '@ai-sdk/openai';
import {
  streamText,
  experimental_wrapLanguageModel as wrapLanguageModel,
} from 'ai';
import 'dotenv/config';
import { yourLogMiddleware } from './your-log-middleware';

async function main() {
  const result = await streamText({
    model: wrapLanguageModel({
      model: openai('gpt-4o'),
      middleware: yourLogMiddleware,
    }),
    prompt: 'What cities are in the United States?',
  });

  for await (const textPart of result.textStream) {
    // consume the stream
  }
}

main().catch(console.error);
```
(new file, +23)

```ts
import { openai } from '@ai-sdk/openai';
import {
  streamText,
  experimental_wrapLanguageModel as wrapLanguageModel,
} from 'ai';
import 'dotenv/config';
import { yourRagMiddleware } from './your-rag-middleware';

async function main() {
  const result = await streamText({
    model: wrapLanguageModel({
      model: openai('gpt-4o'),
      middleware: yourRagMiddleware,
    }),
    prompt: 'What cities are in the United States?',
  });

  for await (const textPart of result.textStream) {
    process.stdout.write(textPart);
  }
}

main().catch(console.error);
```
(new file, +21)

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

const cache = new Map<string, any>();

export const yourCacheMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    const cacheKey = JSON.stringify(params);

    if (cache.has(cacheKey)) {
      return cache.get(cacheKey);
    }

    const result = await doGenerate();

    cache.set(cacheKey, result);

    return result;
  },

  // here you would implement the caching logic for streaming
};
```
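The streaming case left as a comment above could, under one approach, record stream parts while forwarding them and replay the recording on a cache hit. The sketch below uses async iterables as a simplified stand-in for the `ReadableStream` of stream parts that `doStream()` actually returns; the names and types are illustrative assumptions, not the SDK's API:

```ts
// Illustrative sketch only: async iterables stand in for the SDK's stream type.
type StreamPart = { type: 'text-delta'; textDelta: string };

const streamCache = new Map<string, StreamPart[]>();

async function cachedDoStream(
  cacheKey: string,
  doStream: () => AsyncIterable<StreamPart>,
): Promise<AsyncIterable<StreamPart>> {
  const cached = streamCache.get(cacheKey);

  if (cached != null) {
    // cache hit: replay the recorded parts without calling the model
    return (async function* () {
      yield* cached;
    })();
  }

  // cache miss: record each part while forwarding it to the consumer, and
  // only store the recording once the stream has completed successfully
  const source = doStream();
  return (async function* () {
    const recorded: StreamPart[] = [];
    for await (const part of source) {
      recorded.push(part);
      yield part;
    }
    streamCache.set(cacheKey, recorded);
  })();
}
```

Storing the recording only after the loop finishes means an aborted or failed stream never poisons the cache.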
(new file, +16)

```ts
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourGuardrailMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate }) => {
    const { text, ...rest } = await doGenerate();

    // filtering approach, e.g. for PII or other sensitive information:
    const cleanedText = text?.replace(/badword/g, '<REDACTED>');

    return { text: cleanedText, ...rest };
  },

  // here you would implement the guardrail logic for streaming
  // Note: streaming guardrails are difficult to implement, because
  // you do not know the full content of the stream until it's finished.
};
```
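For the streaming case flagged as difficult in the comment above, one simple (latency-sacrificing) option is to buffer the full text and emit a single redacted part once the stream ends; this also catches a blocked word that is split across chunks. As before, async iterables stand in for the SDK's stream type, which is an assumption of this sketch:

```ts
// Buffer-then-redact guardrail sketch: trades away streaming latency for the
// ability to scan the complete output, including words split across chunks.
type TextDeltaPart = { type: 'text-delta'; textDelta: string };

async function* redactedStream(
  source: AsyncIterable<TextDeltaPart>,
  redact: (text: string) => string,
): AsyncIterable<TextDeltaPart> {
  let buffered = '';

  for await (const part of source) {
    buffered += part.textDelta; // hold everything back until the stream ends
  }

  yield { type: 'text-delta', textDelta: redact(buffered) };
}
```

A production design would more likely redact incrementally with a bounded look-behind window to preserve streaming, but that is considerably more involved.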
(new file, +50)

```ts
import type {
  Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware,
  LanguageModelV1StreamPart,
} from 'ai';

export const yourLogMiddleware: LanguageModelV1Middleware = {
  wrapGenerate: async ({ doGenerate, params }) => {
    console.log('doGenerate called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const result = await doGenerate();

    console.log('doGenerate finished');
    console.log(`generated text: ${result.text}`);

    return result;
  },

  wrapStream: async ({ doStream, params }) => {
    console.log('doStream called');
    console.log(`params: ${JSON.stringify(params, null, 2)}`);

    const { stream, ...rest } = await doStream();

    let generatedText = '';

    const transformStream = new TransformStream<
      LanguageModelV1StreamPart,
      LanguageModelV1StreamPart
    >({
      transform(chunk, controller) {
        if (chunk.type === 'text-delta') {
          generatedText += chunk.textDelta;
        }

        controller.enqueue(chunk);
      },

      flush() {
        console.log('doStream finished');
        console.log(`generated text: ${generatedText}`);
      },
    });

    return {
      stream: stream.pipeThrough(transformStream),
      ...rest,
    };
  },
};
```
(new file, +43)

```ts
import { addToLastUserMessage } from './add-to-last-user-message';
import { getLastUserMessageText } from './get-last-user-message-text';
import type { Experimental_LanguageModelV1Middleware as LanguageModelV1Middleware } from 'ai';

export const yourRagMiddleware: LanguageModelV1Middleware = {
  transformParams: async ({ params }) => {
    const lastUserMessageText = getLastUserMessageText({
      prompt: params.prompt,
    });

    if (lastUserMessageText == null) {
      return params; // do not use RAG (send unmodified parameters)
    }

    const instruction =
      'Use the following information to answer the question:\n' +
      findSources({ text: lastUserMessageText })
        .map(chunk => JSON.stringify(chunk))
        .join('\n');

    return addToLastUserMessage({ params, text: instruction });
  },
};

// example, could implement anything here:
function findSources({ text }: { text: string }): Array<{
  title: string;
  previewText: string | undefined;
  url: string | undefined;
}> {
  return [
    {
      title: 'New York',
      previewText: 'New York is a city in the United States.',
      url: 'https://en.wikipedia.org/wiki/New_York',
    },
    {
      title: 'San Francisco',
      previewText: 'San Francisco is a city in the United States.',
      url: 'https://en.wikipedia.org/wiki/San_Francisco',
    },
  ];
}
```
packages/ai/core/index.ts (+1)

```diff
@@ -3,6 +3,7 @@ export type { DeepPartial, Schema } from '@ai-sdk/ui-utils';
 export * from './embed';
 export * from './generate-object';
 export * from './generate-text';
+export * from './middleware';
 export * from './prompt';
 export * from './registry';
 export * from './tool';
```

packages/ai/core/middleware/index.ts (new file, +2)

```ts
export type { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware';
export { experimental_wrapLanguageModel } from './wrap-language-model';
```
(new file, +50)

```ts
import { LanguageModelV1, LanguageModelV1CallOptions } from '@ai-sdk/provider';

/**
 * Experimental middleware for LanguageModelV1.
 * This type defines the structure for middleware that can be used to modify
 * the behavior of LanguageModelV1 operations.
 */
export type Experimental_LanguageModelV1Middleware = {
  /**
   * Transforms the parameters before they are passed to the language model.
   * @param options - Object containing the type of operation and the parameters.
   * @param options.type - The type of operation ('generate' or 'stream').
   * @param options.params - The original parameters for the language model call.
   * @returns A promise that resolves to the transformed parameters.
   */
  transformParams?: (options: {
    type: 'generate' | 'stream';
    params: LanguageModelV1CallOptions;
  }) => PromiseLike<LanguageModelV1CallOptions>;

  /**
   * Wraps the generate operation of the language model.
   * @param options - Object containing the generate function, parameters, and model.
   * @param options.doGenerate - The original generate function.
   * @param options.params - The parameters for the generate call. If the
   * `transformParams` middleware is used, this will be the transformed parameters.
   * @param options.model - The language model instance.
   * @returns A promise that resolves to the result of the generate operation.
   */
  wrapGenerate?: (options: {
    doGenerate: () => ReturnType<LanguageModelV1['doGenerate']>;
    params: LanguageModelV1CallOptions;
    model: LanguageModelV1;
  }) => Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>>;

  /**
   * Wraps the stream operation of the language model.
   * @param options - Object containing the stream function, parameters, and model.
   * @param options.doStream - The original stream function.
   * @param options.params - The parameters for the stream call. If the
   * `transformParams` middleware is used, this will be the transformed parameters.
   * @param options.model - The language model instance.
   * @returns A promise that resolves to the result of the stream operation.
   */
  wrapStream?: (options: {
    doStream: () => ReturnType<LanguageModelV1['doStream']>;
    params: LanguageModelV1CallOptions;
    model: LanguageModelV1;
  }) => PromiseLike<Awaited<ReturnType<LanguageModelV1['doStream']>>>;
};
```
(new file, +151)

```ts
import { LanguageModelV1CallOptions } from '@ai-sdk/provider';
import { experimental_wrapLanguageModel } from '../middleware/wrap-language-model';
import { MockLanguageModelV1 } from '../test/mock-language-model-v1';

it('should pass through model properties', () => {
  const wrappedModel = experimental_wrapLanguageModel({
    model: new MockLanguageModelV1({
      provider: 'test-provider',
      modelId: 'test-model',
      defaultObjectGenerationMode: 'json',
      supportsStructuredOutputs: true,
    }),
    middleware: {},
  });

  expect(wrappedModel.provider).toBe('test-provider');
  expect(wrappedModel.modelId).toBe('test-model');
  expect(wrappedModel.defaultObjectGenerationMode).toBe('json');
  expect(wrappedModel.supportsStructuredOutputs).toBe(true);
});

it('should override provider and modelId if provided', () => {
  const wrappedModel = experimental_wrapLanguageModel({
    model: new MockLanguageModelV1(),
    middleware: {},
    providerId: 'override-provider',
    modelId: 'override-model',
  });

  expect(wrappedModel.provider).toBe('override-provider');
  expect(wrappedModel.modelId).toBe('override-model');
});

it('should call transformParams middleware for doGenerate', async () => {
  const mockModel = new MockLanguageModelV1({
    doGenerate: vi.fn().mockResolvedValue('mock result'),
  });
  const transformParams = vi.fn().mockImplementation(({ params }) => ({
    ...params,
    transformed: true,
  }));

  const wrappedModel = experimental_wrapLanguageModel({
    model: mockModel,
    middleware: { transformParams },
  });

  const params: LanguageModelV1CallOptions = {
    inputFormat: 'messages',
    prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }],
    mode: { type: 'regular' },
  };

  await wrappedModel.doGenerate(params);

  expect(transformParams).toHaveBeenCalledWith({
    params,
    type: 'generate',
  });

  expect(mockModel.doGenerate).toHaveBeenCalledWith({
    ...params,
    transformed: true,
  });
});

it('should call wrapGenerate middleware', async () => {
  const mockModel = new MockLanguageModelV1({
    doGenerate: vi.fn().mockResolvedValue('mock result'),
  });
  const wrapGenerate = vi
    .fn()
    .mockImplementation(({ doGenerate }) => doGenerate());

  const wrappedModel = experimental_wrapLanguageModel({
    model: mockModel,
    middleware: { wrapGenerate },
  });

  const params: LanguageModelV1CallOptions = {
    inputFormat: 'messages',
    prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }],
    mode: { type: 'regular' },
  };

  await wrappedModel.doGenerate(params);

  expect(wrapGenerate).toHaveBeenCalledWith({
    doGenerate: expect.any(Function),
    params,
    model: mockModel,
  });
});

it('should call transformParams middleware for doStream', async () => {
  const mockModel = new MockLanguageModelV1({
    doStream: vi.fn().mockResolvedValue('mock stream'),
  });
  const transformParams = vi.fn().mockImplementation(({ params }) => ({
    ...params,
    transformed: true,
  }));

  const wrappedModel = experimental_wrapLanguageModel({
    model: mockModel,
    middleware: { transformParams },
  });

  const params: LanguageModelV1CallOptions = {
    inputFormat: 'messages',
    prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }],
    mode: { type: 'regular' },
  };

  await wrappedModel.doStream(params);

  expect(transformParams).toHaveBeenCalledWith({
    params,
    type: 'stream',
  });
  expect(mockModel.doStream).toHaveBeenCalledWith({
    ...params,
    transformed: true,
  });
});

it('should call wrapStream middleware', async () => {
  const mockModel = new MockLanguageModelV1({
    doStream: vi.fn().mockResolvedValue('mock stream'),
  });
  const wrapStream = vi.fn().mockImplementation(({ doStream }) => doStream());

  const wrappedModel = experimental_wrapLanguageModel({
    model: mockModel,
    middleware: { wrapStream },
  });

  const params: LanguageModelV1CallOptions = {
    inputFormat: 'messages',
    prompt: [{ role: 'user', content: [{ type: 'text', text: 'Hello' }] }],
    mode: { type: 'regular' },
  };

  await wrappedModel.doStream(params);

  expect(wrapStream).toHaveBeenCalledWith({
    doStream: expect.any(Function),
    params,
    model: mockModel,
  });
});
```
(new file, +67)

```ts
import { LanguageModelV1, LanguageModelV1CallOptions } from '@ai-sdk/provider';
import { Experimental_LanguageModelV1Middleware } from './language-model-v1-middleware';

/**
 * Wraps a LanguageModelV1 instance with middleware functionality.
 * This function allows you to apply middleware to transform parameters,
 * wrap generate operations, and wrap stream operations of a language model.
 *
 * @param options - Configuration options for wrapping the language model.
 * @param options.model - The original LanguageModelV1 instance to be wrapped.
 * @param options.middleware - The middleware to be applied to the language model.
 * @param options.modelId - Optional custom model ID to override the original model's ID.
 * @param options.providerId - Optional custom provider ID to override the original model's provider.
 * @returns A new LanguageModelV1 instance with middleware applied.
 */
export const experimental_wrapLanguageModel = ({
  model,
  middleware: { transformParams, wrapGenerate, wrapStream },
  modelId,
  providerId,
}: {
  model: LanguageModelV1;
  middleware: Experimental_LanguageModelV1Middleware;
  modelId?: string;
  providerId?: string;
}): LanguageModelV1 => {
  async function doTransform({
    params,
    type,
  }: {
    params: LanguageModelV1CallOptions;
    type: 'generate' | 'stream';
  }) {
    return transformParams ? await transformParams({ params, type }) : params;
  }

  return {
    specificationVersion: 'v1',

    provider: providerId ?? model.provider,
    modelId: modelId ?? model.modelId,

    defaultObjectGenerationMode: model.defaultObjectGenerationMode,
    supportsImageUrls: model.supportsImageUrls,
    supportsStructuredOutputs: model.supportsStructuredOutputs,

    async doGenerate(
      params: LanguageModelV1CallOptions,
    ): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
      const transformedParams = await doTransform({ params, type: 'generate' });
      const doGenerate = async () => model.doGenerate(transformedParams);
      return wrapGenerate
        ? wrapGenerate({ doGenerate, params: transformedParams, model })
        : doGenerate();
    },

    async doStream(
      params: LanguageModelV1CallOptions,
    ): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
      const transformedParams = await doTransform({ params, type: 'stream' });
      const doStream = async () => model.doStream(transformedParams);
      return wrapStream
        ? wrapStream({ doStream, params: transformedParams, model })
        : doStream();
    },
  };
};
```

packages/ai/core/types/language-model.ts (+8)

```diff
@@ -5,6 +5,14 @@ import {
   LanguageModelV1LogProbs,
 } from '@ai-sdk/provider';
 
+// Re-export LanguageModelV1 types for the middleware:
+export type {
+  LanguageModelV1,
+  LanguageModelV1CallOptions,
+  LanguageModelV1Prompt,
+  LanguageModelV1StreamPart,
+} from '@ai-sdk/provider';
+
 /**
 Language model that is used by the AI SDK Core functions.
 */
```
