Commit 83ec4ac

Authored May 31, 2024

Add GenerateContentRequest as an optional param to CountTokensRequest (#148)

Expand the model's countTokens method to alternatively accept a GenerateContentRequest. Added integration and unit tests.

1 parent 9318ee8 · commit 83ec4ac

10 files changed: +132 −8 lines

.changeset/dirty-wolves-sin.md (+5)

```diff
@@ -0,0 +1,5 @@
+---
+"@google/generative-ai": minor
+---
+
+Expand the model's `countTokens` method to alternatively accept a `GenerateContentRequest`.
```

docs/reference/main/generative-ai.counttokensrequest.contents.md (+1 −1)

````diff
@@ -7,5 +7,5 @@
 **Signature:**
 
 ```typescript
-contents: Content[];
+contents?: Content[];
 ```
````
docs/reference/main/generative-ai.counttokensrequest.generatecontentrequest.md (new file, +11)

````diff
@@ -0,0 +1,11 @@
+<!-- Do not edit this file. It is automatically generated by API Documenter. -->
+
+[Home](./index.md) &gt; [@google/generative-ai](./generative-ai.md) &gt; [CountTokensRequest](./generative-ai.counttokensrequest.md) &gt; [generateContentRequest](./generative-ai.counttokensrequest.generatecontentrequest.md)
+
+## CountTokensRequest.generateContentRequest property
+
+**Signature:**
+
+```typescript
+generateContentRequest?: GenerateContentRequest;
+```
````

docs/reference/main/generative-ai.counttokensrequest.md (+5 −2)

```diff
@@ -4,7 +4,9 @@
 
 ## CountTokensRequest interface
 
-Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md)
+Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md)<!-- -->.
+
+The request must contain either a [Content](./generative-ai.content.md) array or a [GenerateContentRequest](./generative-ai.generatecontentrequest.md)<!-- -->, but not both. If both are provided then a [GoogleGenerativeAIRequestInputError](./generative-ai.googlegenerativeairequestinputerror.md) is thrown.
 
 **Signature:**
 
@@ -16,5 +18,6 @@ export interface CountTokensRequest
 
 | Property | Modifiers | Type | Description |
 | --- | --- | --- | --- |
-| [contents](./generative-ai.counttokensrequest.contents.md) | | [Content](./generative-ai.content.md)<!-- -->\[\] | |
+| [contents?](./generative-ai.counttokensrequest.contents.md) | | [Content](./generative-ai.content.md)<!-- -->\[\] | _(Optional)_ |
+| [generateContentRequest?](./generative-ai.counttokensrequest.generatecontentrequest.md) | | [GenerateContentRequest](./generative-ai.generatecontentrequest.md) | _(Optional)_ |
 
```
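The reference entry above makes `contents` and `generateContentRequest` mutually exclusive. As a rough standalone illustration of that rule: the interfaces below are pared-down stand-ins for the library's types, and `validateCountTokensRequest` is a hypothetical helper written for this sketch, not part of the SDK.

```typescript
// Pared-down stand-ins for the SDK's types (illustration only).
interface Part { text: string; }
interface Content { role: string; parts: Part[]; }
interface GenerateContentRequest { contents: Content[]; }
interface CountTokensRequest {
  contents?: Content[];
  generateContentRequest?: GenerateContentRequest;
}

// Hypothetical helper mirroring the documented rule:
// one of the two fields may be set, never both.
function validateCountTokensRequest(req: CountTokensRequest): void {
  if (req.contents && req.generateContentRequest) {
    throw new Error(
      "CountTokensRequest must have one of contents or generateContentRequest, not both.",
    );
  }
}

// Both of the documented request shapes pass validation.
const viaContents: CountTokensRequest = {
  contents: [{ role: "user", parts: [{ text: "hello" }] }],
};
const viaGenerateContentRequest: CountTokensRequest = {
  generateContentRequest: {
    contents: [{ role: "user", parts: [{ text: "hello" }] }],
  },
};
validateCountTokensRequest(viaContents); // ok
validateCountTokensRequest(viaGenerateContentRequest); // ok
```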

docs/reference/main/generative-ai.md (+1 −1)

```diff
@@ -40,7 +40,7 @@
 | [CitationSource](./generative-ai.citationsource.md) | A single citation source. |
 | [Content](./generative-ai.content.md) | Content type for both prompts and response candidates. |
 | [ContentEmbedding](./generative-ai.contentembedding.md) | A single content embedding. |
-| [CountTokensRequest](./generative-ai.counttokensrequest.md) | Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md) |
+| [CountTokensRequest](./generative-ai.counttokensrequest.md) | <p>Params for calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md)<!-- -->.</p><p>The request must contain either a [Content](./generative-ai.content.md) array or a [GenerateContentRequest](./generative-ai.generatecontentrequest.md)<!-- -->, but not both. If both are provided then a [GoogleGenerativeAIRequestInputError](./generative-ai.googlegenerativeairequestinputerror.md) is thrown.</p> |
 | [CountTokensResponse](./generative-ai.counttokensresponse.md) | Response from calling [GenerativeModel.countTokens()](./generative-ai.generativemodel.counttokens.md)<!-- -->. |
 | [EmbedContentRequest](./generative-ai.embedcontentrequest.md) | Params for calling [GenerativeModel.embedContent()](./generative-ai.generativemodel.embedcontent.md) |
 | [EmbedContentResponse](./generative-ai.embedcontentresponse.md) | Response from calling [GenerativeModel.embedContent()](./generative-ai.generativemodel.embedcontent.md)<!-- -->. |
```

packages/main/src/models/generative-model.test.ts (+31)

```diff
@@ -18,6 +18,7 @@ import { expect, use } from "chai";
 import { GenerativeModel } from "./generative-model";
 import * as sinonChai from "sinon-chai";
 import {
+  CountTokensRequest,
   FunctionCallingMode,
   FunctionDeclarationSchemaType,
   HarmBlockThreshold,
@@ -409,4 +410,34 @@ describe("GenerativeModel", () => {
     );
     restore();
   });
+  it("countTokens errors if contents and generateContentRequest are both defined", async () => {
+    const genModel = new GenerativeModel(
+      "apiKey",
+      {
+        model: "my-model",
+      },
+      {
+        apiVersion: "v2000",
+      },
+    );
+    const mockResponse = getMockResponse(
+      "unary-success-basic-reply-short.json",
+    );
+    const makeRequestStub = stub(request, "makeRequest").resolves(
+      mockResponse as Response,
+    );
+    const countTokensRequest: CountTokensRequest = {
+      contents: [{ role: "user", parts: [{ text: "hello" }] }],
+      generateContentRequest: {
+        contents: [{ role: "user", parts: [{ text: "hello" }] }],
+      },
+    };
+    await expect(
+      genModel.countTokens(countTokensRequest),
+    ).to.eventually.be.rejectedWith(
+      "CountTokensRequest must have one of contents or generateContentRequest, not both.",
+    );
+    expect(makeRequestStub).to.not.be.called;
+    restore();
+  });
 });
```

packages/main/src/models/generative-model.ts (+2 −1)

```diff
@@ -43,6 +43,7 @@ import { ChatSession } from "../methods/chat-session";
 import { countTokens } from "../methods/count-tokens";
 import { batchEmbedContents, embedContent } from "../methods/embed-content";
 import {
+  formatCountTokensInput,
   formatEmbedContentInput,
   formatGenerateContentInput,
   formatSystemInstruction,
@@ -157,7 +158,7 @@ export class GenerativeModel {
   async countTokens(
     request: CountTokensRequest | string | Array<string | Part>,
   ): Promise<CountTokensResponse> {
-    const formattedParams = formatGenerateContentInput(request);
+    const formattedParams = formatCountTokensInput(request, this.model);
     return countTokens(
       this.apiKey,
       this.model,
```

packages/main/src/requests/request-helpers.ts (+31 −1)

```diff
@@ -17,11 +17,16 @@
 
 import {
   Content,
+  CountTokensRequest,
+  CountTokensRequestInternal,
   EmbedContentRequest,
   GenerateContentRequest,
   Part,
 } from "../../types";
-import { GoogleGenerativeAIError } from "../errors";
+import {
+  GoogleGenerativeAIError,
+  GoogleGenerativeAIRequestInputError,
+} from "../errors";
 
 export function formatSystemInstruction(
   input?: string | Part | Content,
@@ -104,6 +109,31 @@ function assignRoleToPartsAndValidateSendMessageRequest(
   return functionContent;
 }
 
+export function formatCountTokensInput(
+  params: CountTokensRequest | string | Array<string | Part>,
+  model: string,
+): CountTokensRequestInternal {
+  let formattedRequest: CountTokensRequestInternal = {};
+  const containsGenerateContentRequest =
+    (params as CountTokensRequest).generateContentRequest != null;
+  if ((params as CountTokensRequest).contents) {
+    if (containsGenerateContentRequest) {
+      throw new GoogleGenerativeAIRequestInputError(
+        "CountTokensRequest must have one of contents or generateContentRequest, not both.",
+      );
+    }
+    formattedRequest = { ...(params as CountTokensRequest) };
+  } else if (containsGenerateContentRequest) {
+    formattedRequest = { ...(params as CountTokensRequest) };
+    formattedRequest.generateContentRequest.model = model;
+  } else {
+    // Array or string
+    const content = formatNewContent(params as string | Array<string | Part>);
+    formattedRequest.contents = [content];
+  }
+  return formattedRequest;
+}
+
 export function formatGenerateContentInput(
   params: GenerateContentRequest | string | Array<string | Part>,
 ): GenerateContentRequest {
```
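The branching in the new `formatCountTokensInput` can be exercised in isolation. The sketch below is a simplified stand-in, not the SDK source: it re-declares minimal versions of the types and collapses the `Array<string | Part>` case into plain strings, but preserves the three behaviors the diff adds (reject both fields, stamp the model name onto a nested `generateContentRequest`, wrap bare text into a user `Content`).

```typescript
// Minimal stand-ins for the SDK's types (illustration only).
interface Part { text: string; }
interface Content { role: string; parts: Part[]; }
interface GenerateContentRequestInternal { contents: Content[]; model?: string; }
interface CountTokensRequestInternal {
  contents?: Content[];
  generateContentRequest?: GenerateContentRequestInternal;
}

// Simplified mirror of the diff's logic.
function formatCountTokensInput(
  params: CountTokensRequestInternal | string,
  model: string,
): CountTokensRequestInternal {
  if (typeof params === "string") {
    // Bare text gets wrapped into a single user Content.
    return { contents: [{ role: "user", parts: [{ text: params }] }] };
  }
  if (params.contents && params.generateContentRequest) {
    // The two fields are mutually exclusive.
    throw new Error(
      "CountTokensRequest must have one of contents or generateContentRequest, not both.",
    );
  }
  if (params.generateContentRequest) {
    // The model name is stamped onto the nested request so the
    // backend knows which model to count tokens for.
    return {
      generateContentRequest: { ...params.generateContentRequest, model },
    };
  }
  return { ...params };
}

const formatted = formatCountTokensInput(
  {
    generateContentRequest: {
      contents: [{ role: "user", parts: [{ text: "hi" }] }],
    },
  },
  "models/my-model",
);
console.log(formatted.generateContentRequest?.model); // "models/my-model"
```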

packages/main/test-integration/node/count-tokens.test.ts (+20)

```diff
@@ -18,6 +18,7 @@
 import { expect, use } from "chai";
 import * as chaiAsPromised from "chai-as-promised";
 import { GoogleGenerativeAI, HarmBlockThreshold, HarmCategory } from "../..";
+import { CountTokensRequest } from "../../types";
 
 use(chaiAsPromised);
 
@@ -46,4 +47,23 @@ describe("countTokens", function () {
     expect(response1.totalTokens).to.equal(3);
     expect(response2.totalTokens).to.equal(3);
   });
+  it("counts tokens with GenerateContentRequest", async () => {
+    const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
+    const model = genAI.getGenerativeModel({
+      model: "gemini-1.5-flash-latest",
+      safetySettings: [
+        {
+          category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+          threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+        },
+      ],
+    });
+    const countTokensRequest: CountTokensRequest = {
+      generateContentRequest: {
+        contents: [{ role: "user", parts: [{ text: "count me" }] }],
+      },
+    };
+    const response = await model.countTokens(countTokensRequest);
+    expect(response.totalTokens).to.equal(3);
+  });
 });
```

packages/main/types/requests.ts (+25 −2)

```diff
@@ -54,6 +54,14 @@ export interface GenerateContentRequest extends BaseParams {
   systemInstruction?: string | Part | Content;
 }
 
+/**
+ * Request sent to `generateContent` endpoint.
+ * @internal
+ */
+export interface GenerateContentRequestInternal extends GenerateContentRequest {
+  model?: string;
+}
+
 /**
  * Safety setting that can be sent as part of request parameters.
  * @public
@@ -101,11 +109,26 @@ export interface StartChatParams extends BaseParams {
 }
 
 /**
- * Params for calling {@link GenerativeModel.countTokens}
+ * Params for calling {@link GenerativeModel.countTokens}.
+ *
+ * The request must contain either a {@link Content} array or a
+ * {@link GenerateContentRequest}, but not both. If both are provided
+ * then a {@link GoogleGenerativeAIRequestInputError} is thrown.
+ *
  * @public
  */
 export interface CountTokensRequest {
-  contents: Content[];
+  generateContentRequest?: GenerateContentRequest;
+  contents?: Content[];
+}
+
+/**
+ * Params for calling {@link GenerativeModel.countTokens}
+ * @internal
+ */
+export interface CountTokensRequestInternal {
+  generateContentRequest?: GenerateContentRequestInternal;
+  contents?: Content[];
 }
 
 /**
```
