Skip to content

Commit ca62400

Browse files
authoredApr 29, 2024··
Allow user to provide systemInstruction as string or Part (#113)
1 parent 111e970 commit ca62400

12 files changed

+273
-13
lines changed
 

‎.changeset/perfect-hotels-protect.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@google/generative-ai": minor
3+
---
4+
5+
Allow text-only systemInstruction as well as Part and Content.

‎docs/reference/main/generative-ai.generatecontentrequest.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export interface GenerateContentRequest extends BaseParams
1818
| Property | Modifiers | Type | Description |
1919
| --- | --- | --- | --- |
2020
| [contents](./generative-ai.generatecontentrequest.contents.md) | | [Content](./generative-ai.content.md)<!-- -->\[\] | |
21-
| [systemInstruction?](./generative-ai.generatecontentrequest.systeminstruction.md) | | [Content](./generative-ai.content.md) | _(Optional)_ |
21+
| [systemInstruction?](./generative-ai.generatecontentrequest.systeminstruction.md) | | string \| [Part](./generative-ai.part.md) \| [Content](./generative-ai.content.md) | _(Optional)_ |
2222
| [toolConfig?](./generative-ai.generatecontentrequest.toolconfig.md) | | [ToolConfig](./generative-ai.toolconfig.md) | _(Optional)_ |
2323
| [tools?](./generative-ai.generatecontentrequest.tools.md) | | [Tool](./generative-ai.tool.md)<!-- -->\[\] | _(Optional)_ |
2424

‎docs/reference/main/generative-ai.generatecontentrequest.systeminstruction.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
**Signature:**
88

99
```typescript
10-
systemInstruction?: Content;
10+
systemInstruction?: string | Part | Content;
1111
```

‎docs/reference/main/generative-ai.modelparams.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export interface ModelParams extends BaseParams
1818
| Property | Modifiers | Type | Description |
1919
| --- | --- | --- | --- |
2020
| [model](./generative-ai.modelparams.model.md) | | string | |
21-
| [systemInstruction?](./generative-ai.modelparams.systeminstruction.md) | | [Content](./generative-ai.content.md) | _(Optional)_ |
21+
| [systemInstruction?](./generative-ai.modelparams.systeminstruction.md) | | string \| [Part](./generative-ai.part.md) \| [Content](./generative-ai.content.md) | _(Optional)_ |
2222
| [toolConfig?](./generative-ai.modelparams.toolconfig.md) | | [ToolConfig](./generative-ai.toolconfig.md) | _(Optional)_ |
2323
| [tools?](./generative-ai.modelparams.tools.md) | | [Tool](./generative-ai.tool.md)<!-- -->\[\] | _(Optional)_ |
2424

‎docs/reference/main/generative-ai.modelparams.systeminstruction.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
**Signature:**
88

99
```typescript
10-
systemInstruction?: Content;
10+
systemInstruction?: string | Part | Content;
1111
```

‎docs/reference/main/generative-ai.startchatparams.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export interface StartChatParams extends BaseParams
1818
| Property | Modifiers | Type | Description |
1919
| --- | --- | --- | --- |
2020
| [history?](./generative-ai.startchatparams.history.md) | | [Content](./generative-ai.content.md)<!-- -->\[\] | _(Optional)_ |
21-
| [systemInstruction?](./generative-ai.startchatparams.systeminstruction.md) | | [Content](./generative-ai.content.md) | _(Optional)_ |
21+
| [systemInstruction?](./generative-ai.startchatparams.systeminstruction.md) | | string \| [Part](./generative-ai.part.md) \| [Content](./generative-ai.content.md) | _(Optional)_ |
2222
| [toolConfig?](./generative-ai.startchatparams.toolconfig.md) | | [ToolConfig](./generative-ai.toolconfig.md) | _(Optional)_ |
2323
| [tools?](./generative-ai.startchatparams.tools.md) | | [Tool](./generative-ai.tool.md)<!-- -->\[\] | _(Optional)_ |
2424

‎docs/reference/main/generative-ai.startchatparams.systeminstruction.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
**Signature:**
88

99
```typescript
10-
systemInstruction?: Content;
10+
systemInstruction?: string | Part | Content;
1111
```

‎packages/main/src/models/generative-model.test.ts

+50
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,31 @@ describe("GenerativeModel", () => {
101101
);
102102
restore();
103103
});
104+
it("passes text-only systemInstruction through to generateContent", async () => {
105+
const genModel = new GenerativeModel("apiKey", {
106+
model: "my-model",
107+
systemInstruction: "be friendly",
108+
});
109+
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
110+
const mockResponse = getMockResponse(
111+
"unary-success-basic-reply-short.json",
112+
);
113+
const makeRequestStub = stub(request, "makeRequest").resolves(
114+
mockResponse as Response,
115+
);
116+
await genModel.generateContent("hello");
117+
expect(makeRequestStub).to.be.calledWith(
118+
"models/my-model",
119+
request.Task.GENERATE_CONTENT,
120+
match.any,
121+
false,
122+
match((value: string) => {
123+
return value.includes("be friendly");
124+
}),
125+
match.any,
126+
);
127+
restore();
128+
});
104129
it("generateContent overrides model values", async () => {
105130
const genModel = new GenerativeModel("apiKey", {
106131
model: "my-model",
@@ -226,6 +251,31 @@ describe("GenerativeModel", () => {
226251
);
227252
restore();
228253
});
254+
it("passes params through to chat.sendMessage", async () => {
255+
const genModel = new GenerativeModel("apiKey", {
256+
model: "my-model",
257+
systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
258+
});
259+
expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
260+
const mockResponse = getMockResponse(
261+
"unary-success-basic-reply-short.json",
262+
);
263+
const makeRequestStub = stub(request, "makeRequest").resolves(
264+
mockResponse as Response,
265+
);
266+
await genModel.startChat().sendMessage("hello");
267+
expect(makeRequestStub).to.be.calledWith(
268+
"models/my-model",
269+
request.Task.GENERATE_CONTENT,
270+
match.any,
271+
false,
272+
match((value: string) => {
273+
return value.includes("be friendly");
274+
}),
275+
{},
276+
);
277+
restore();
278+
});
229279
it("startChat overrides model values", async () => {
230280
const genModel = new GenerativeModel("apiKey", {
231281
model: "my-model",

‎packages/main/src/models/generative-model.ts

+4-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ import { batchEmbedContents, embedContent } from "../methods/embed-content";
4545
import {
4646
formatEmbedContentInput,
4747
formatGenerateContentInput,
48+
formatSystemInstruction,
4849
} from "../requests/request-helpers";
4950

5051
/**
@@ -76,7 +77,9 @@ export class GenerativeModel {
7677
this.safetySettings = modelParams.safetySettings || [];
7778
this.tools = modelParams.tools;
7879
this.toolConfig = modelParams.toolConfig;
79-
this.systemInstruction = modelParams.systemInstruction;
80+
this.systemInstruction = formatSystemInstruction(
81+
modelParams.systemInstruction,
82+
);
8083
this.requestOptions = requestOptions || {};
8184
}
8285

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/**
2+
* @license
3+
* Copyright 2024 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
import { expect, use } from "chai";
19+
import * as sinonChai from "sinon-chai";
20+
import { Content } from "../../types";
21+
import { formatGenerateContentInput } from "./request-helpers";
22+
23+
use(sinonChai);
24+
25+
describe("request formatting methods", () => {
26+
describe("formatGenerateContentInput", () => {
27+
it("formats a text string into a request", () => {
28+
const result = formatGenerateContentInput("some text content");
29+
expect(result).to.deep.equal({
30+
contents: [
31+
{
32+
role: "user",
33+
parts: [{ text: "some text content" }],
34+
},
35+
],
36+
});
37+
});
38+
it("formats an array of strings into a request", () => {
39+
const result = formatGenerateContentInput(["txt1", "txt2"]);
40+
expect(result).to.deep.equal({
41+
contents: [
42+
{
43+
role: "user",
44+
parts: [{ text: "txt1" }, { text: "txt2" }],
45+
},
46+
],
47+
});
48+
});
49+
it("formats an array of Parts into a request", () => {
50+
const result = formatGenerateContentInput([
51+
{ text: "txt1" },
52+
{ text: "txtB" },
53+
]);
54+
expect(result).to.deep.equal({
55+
contents: [
56+
{
57+
role: "user",
58+
parts: [{ text: "txt1" }, { text: "txtB" }],
59+
},
60+
],
61+
});
62+
});
63+
it("formats a mixed array into a request", () => {
64+
const result = formatGenerateContentInput(["txtA", { text: "txtB" }]);
65+
expect(result).to.deep.equal({
66+
contents: [
67+
{
68+
role: "user",
69+
parts: [{ text: "txtA" }, { text: "txtB" }],
70+
},
71+
],
72+
});
73+
});
74+
it("preserves other properties of request", () => {
75+
const result = formatGenerateContentInput({
76+
contents: [
77+
{
78+
role: "user",
79+
parts: [{ text: "txtA" }],
80+
},
81+
],
82+
generationConfig: { topK: 100 },
83+
});
84+
expect(result).to.deep.equal({
85+
contents: [
86+
{
87+
role: "user",
88+
parts: [{ text: "txtA" }],
89+
},
90+
],
91+
generationConfig: { topK: 100 },
92+
});
93+
});
94+
it("formats systemInstructions if provided as text", () => {
95+
const result = formatGenerateContentInput({
96+
contents: [
97+
{
98+
role: "user",
99+
parts: [{ text: "txtA" }],
100+
},
101+
],
102+
systemInstruction: "be excited",
103+
});
104+
expect(result).to.deep.equal({
105+
contents: [
106+
{
107+
role: "user",
108+
parts: [{ text: "txtA" }],
109+
},
110+
],
111+
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
112+
});
113+
});
114+
it("formats systemInstructions if provided as Part", () => {
115+
const result = formatGenerateContentInput({
116+
contents: [
117+
{
118+
role: "user",
119+
parts: [{ text: "txtA" }],
120+
},
121+
],
122+
systemInstruction: { text: "be excited" },
123+
});
124+
expect(result).to.deep.equal({
125+
contents: [
126+
{
127+
role: "user",
128+
parts: [{ text: "txtA" }],
129+
},
130+
],
131+
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
132+
});
133+
});
134+
it("formats systemInstructions if provided as Content (no role)", () => {
135+
const result = formatGenerateContentInput({
136+
contents: [
137+
{
138+
role: "user",
139+
parts: [{ text: "txtA" }],
140+
},
141+
],
142+
systemInstruction: { parts: [{ text: "be excited" }] } as Content,
143+
});
144+
expect(result).to.deep.equal({
145+
contents: [
146+
{
147+
role: "user",
148+
parts: [{ text: "txtA" }],
149+
},
150+
],
151+
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
152+
});
153+
});
154+
it("passes thru systemInstructions if provided as Content", () => {
155+
const result = formatGenerateContentInput({
156+
contents: [
157+
{
158+
role: "user",
159+
parts: [{ text: "txtA" }],
160+
},
161+
],
162+
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
163+
});
164+
expect(result).to.deep.equal({
165+
contents: [
166+
{
167+
role: "user",
168+
parts: [{ text: "txtA" }],
169+
},
170+
],
171+
systemInstruction: { role: "system", parts: [{ text: "be excited" }] },
172+
});
173+
});
174+
});
175+
});

‎packages/main/src/requests/request-helpers.ts

+29-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,25 @@ import {
2323
} from "../../types";
2424
import { GoogleGenerativeAIError } from "../errors";
2525

26+
export function formatSystemInstruction(
27+
input?: string | Part | Content,
28+
): Content | undefined {
29+
// null or undefined
30+
if (input == null) {
31+
return undefined;
32+
} else if (typeof input === "string") {
33+
return { role: "system", parts: [{ text: input }] } as Content;
34+
} else if ((input as Part).text) {
35+
return { role: "system", parts: [input as Part] };
36+
} else if ((input as Content).parts) {
37+
if (!(input as Content).role) {
38+
return { role: "system", parts: (input as Content).parts };
39+
} else {
40+
return input as Content;
41+
}
42+
}
43+
}
44+
2645
export function formatNewContent(
2746
request: string | Array<string | Part>,
2847
): Content {
@@ -88,12 +107,20 @@ function assignRoleToPartsAndValidateSendMessageRequest(
88107
export function formatGenerateContentInput(
89108
params: GenerateContentRequest | string | Array<string | Part>,
90109
): GenerateContentRequest {
110+
let formattedRequest: GenerateContentRequest;
91111
if ((params as GenerateContentRequest).contents) {
92-
return params as GenerateContentRequest;
112+
formattedRequest = params as GenerateContentRequest;
93113
} else {
114+
// Array or string
94115
const content = formatNewContent(params as string | Array<string | Part>);
95-
return { contents: [content] };
116+
formattedRequest = { contents: [content] };
117+
}
118+
if ((params as GenerateContentRequest).systemInstruction) {
119+
formattedRequest.systemInstruction = formatSystemInstruction(
120+
(params as GenerateContentRequest).systemInstruction,
121+
);
96122
}
123+
return formattedRequest;
97124
}
98125

99126
export function formatEmbedContentInput(

‎packages/main/types/requests.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the License.
1616
*/
1717

18-
import { Content } from "./content";
18+
import { Content, Part } from "./content";
1919
import {
2020
FunctionCallingMode,
2121
HarmBlockThreshold,
@@ -40,7 +40,7 @@ export interface ModelParams extends BaseParams {
4040
model: string;
4141
tools?: Tool[];
4242
toolConfig?: ToolConfig;
43-
systemInstruction?: Content;
43+
systemInstruction?: string | Part | Content;
4444
}
4545

4646
/**
@@ -51,7 +51,7 @@ export interface GenerateContentRequest extends BaseParams {
5151
contents: Content[];
5252
tools?: Tool[];
5353
toolConfig?: ToolConfig;
54-
systemInstruction?: Content;
54+
systemInstruction?: string | Part | Content;
5555
}
5656

5757
/**
@@ -84,7 +84,7 @@ export interface StartChatParams extends BaseParams {
8484
history?: Content[];
8585
tools?: Tool[];
8686
toolConfig?: ToolConfig;
87-
systemInstruction?: Content;
87+
systemInstruction?: string | Part | Content;
8888
}
8989

9090
/**

0 commit comments

Comments
 (0)
Please sign in to comment.