Skip to content

Commit edd6d3f

Browse files
authored
Fix schema input (response_model) for Google AI client (#687)
* Fix schema input (response_model) for Google AI client * acknowledgement
1 parent bc5a731 commit edd6d3f

File tree

4 files changed

+131
-19
lines changed

4 files changed

+131
-19
lines changed

Diff for: .changeset/short-banks-sit.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@browserbasehq/stagehand": patch
3+
---
4+
5+
Fixed the schema input for Gemini's response model

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ For more information, please see our [Contributing Guide](https://docs.stagehand
128128

129129
## Acknowledgements
130130

131-
This project heavily relies on [Playwright](https://playwright.dev/) as a resilient backbone to automate the web. It also would not be possible without the awesome techniques and discoveries made by [tarsier](https://github.com/reworkd/tarsier), and [fuji-web](https://github.com/normal-computing/fuji-web).
131+
This project heavily relies on [Playwright](https://playwright.dev/) as a resilient backbone to automate the web. It also would not be possible without the awesome techniques and discoveries made by [tarsier](https://github.com/reworkd/tarsier), [gemini-zod](https://github.com/jbeoris/gemini-zod), and [fuji-web](https://github.com/normal-computing/fuji-web).
132132

133133
We'd like to thank the following people for their major contributions to Stagehand:
134134
- [Paul Klein](https://github.com/pkiv)

Diff for: lib/llm/GoogleClient.ts

+4-18
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import zodToJsonSchema from "zod-to-json-schema";
1414
import { LogLine } from "../../types/log";
1515
import { AvailableModel, ClientOptions } from "../../types/model";
1616
import { LLMCache } from "../cache/LLMCache";
17-
import { validateZodSchema } from "../utils";
17+
import { validateZodSchema, toGeminiSchema } from "../utils";
1818
import {
1919
ChatCompletionOptions,
2020
ChatMessage,
@@ -290,25 +290,11 @@ export class GoogleClient extends LLMClient {
290290
temperature: temperature,
291291
topP: top_p,
292292
responseMimeType: response_model ? "application/json" : undefined,
293+
responseSchema: response_model
294+
? toGeminiSchema(response_model.schema)
295+
: undefined,
293296
};
294297

295-
// Handle JSON mode instructions
296-
if (response_model) {
297-
// Prepend instructions for JSON output if needed (similar to o1 handling)
298-
const schemaString = JSON.stringify(
299-
zodToJsonSchema(response_model.schema),
300-
);
301-
formattedMessages.push({
302-
role: "user",
303-
parts: [
304-
{
305-
text: `Please respond ONLY with a valid JSON object that strictly adheres to the following JSON schema. Do not include any other text, explanations, or markdown formatting like \`\`\`json ... \`\`\`. Just the JSON object.\n\nSchema:\n${schemaString}`,
306-
},
307-
],
308-
});
309-
formattedMessages.push({ role: "model", parts: [{ text: "{" }] }); // Prime the model
310-
}
311-
312298
logger({
313299
category: "google",
314300
message: "creating chat completion",

Diff for: lib/utils.ts

+121
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { z } from "zod";
33
import { ObserveResult, Page } from ".";
44
import { LogLine } from "../types/log";
55
import { TextAnnotation } from "../types/textannotation";
6+
import { Schema, Type } from "@google/genai";
67

78
// This is a heuristic for the width of a character in pixels. It seems to work
89
// better than attempting to calculate character widths dynamically, which sometimes
@@ -442,3 +443,123 @@ export function isRunningInBun(): boolean {
442443
"bun" in process.versions
443444
);
444445
}
446+
447+
/*
448+
* Helper functions for converting Zod schemas into Gemini response schemas
449+
*/
450+
/**
 * Fills in the metadata fields of a Gemini schema from the Zod schema it was
 * derived from.
 *
 * The Gemini schema is mutated in place and the same object is returned.
 *
 * @param geminiSchema - Gemini schema to annotate; an already-set `nullable`
 *   value is left untouched.
 * @param zodSchema - Source Zod schema supplying optionality and description.
 * @returns The `geminiSchema` argument, after annotation.
 */
function decorateGeminiSchema(
  geminiSchema: Schema,
  zodSchema: z.ZodTypeAny,
): Schema {
  // Only derive nullability when the caller has not decided it explicitly.
  const nullabilityUnset = geminiSchema.nullable === undefined;
  if (nullabilityUnset) {
    geminiSchema.nullable = zodSchema.isOptional();
  }

  // Empty or missing descriptions are skipped so no blank field is emitted.
  const note = zodSchema.description;
  if (note) {
    geminiSchema.description = note;
  }

  return geminiSchema;
}
464+
465+
/**
 * Converts a Zod schema into the `Schema` structure passed to Gemini as
 * `responseSchema`.
 *
 * Handled Zod kinds: arrays, objects, strings, numbers, booleans, enums,
 * literals, the optional / nullable / default wrappers, and effects
 * (`.refine()` / `.transform()`), which are unwrapped to their inner schema.
 * Any other kind falls back to a nullable `OBJECT` placeholder with no
 * properties.
 *
 * @param zodSchema - The Zod schema to translate.
 * @returns The corresponding Gemini schema, annotated with nullability and
 *   description via `decorateGeminiSchema`.
 */
export function toGeminiSchema(zodSchema: z.ZodTypeAny): Schema {
  const zodType = getZodType(zodSchema);

  switch (zodType) {
    case "ZodArray": {
      return decorateGeminiSchema(
        {
          type: Type.ARRAY,
          items: toGeminiSchema(
            (zodSchema as z.ZodArray<z.ZodTypeAny>).element,
          ),
        },
        zodSchema,
      );
    }
    case "ZodObject": {
      const properties: Record<string, Schema> = {};
      const required: string[] = [];
      const shape = (zodSchema as z.ZodObject<z.ZodRawShape>).shape;

      for (const [key, value] of Object.entries(shape)) {
        properties[key] = toGeminiSchema(value);
        // Ask Zod whether the field accepts `undefined` instead of comparing
        // the outermost type name: an optional wrapped in another layer
        // (e.g. `.optional().nullable()` or `.default(...)`) must not be
        // reported as required.
        if (!value.isOptional()) {
          required.push(key);
        }
      }

      return decorateGeminiSchema(
        {
          type: Type.OBJECT,
          properties,
          required: required.length > 0 ? required : undefined,
        },
        zodSchema,
      );
    }
    case "ZodString":
      return decorateGeminiSchema({ type: Type.STRING }, zodSchema);
    case "ZodNumber":
      return decorateGeminiSchema({ type: Type.NUMBER }, zodSchema);
    case "ZodBoolean":
      return decorateGeminiSchema({ type: Type.BOOLEAN }, zodSchema);
    case "ZodEnum":
      return decorateGeminiSchema(
        {
          type: Type.STRING,
          enum: zodSchema._def.values,
        },
        zodSchema,
      );
    case "ZodDefault":
    case "ZodNullable":
    case "ZodOptional": {
      const innerSchema = toGeminiSchema(zodSchema._def.innerType);
      return decorateGeminiSchema(
        {
          ...innerSchema,
          nullable: true,
        },
        zodSchema,
      );
    }
    case "ZodEffects": {
      // `.refine()` / `.transform()` wrap the real schema in `_def.schema`;
      // translate that so the structure is not lost to the fallback branch.
      // A description set on the wrapper overrides the inner one.
      return decorateGeminiSchema(
        toGeminiSchema(zodSchema._def.schema),
        zodSchema,
      );
    }
    case "ZodLiteral": {
      const literal: unknown = zodSchema._def.value;
      // Gemini's `enum` holds strings only, so it is emitted solely for
      // string literals; other primitives are described by their type.
      switch (typeof literal) {
        case "string":
          return decorateGeminiSchema(
            { type: Type.STRING, enum: [literal] },
            zodSchema,
          );
        case "number":
          return decorateGeminiSchema({ type: Type.NUMBER }, zodSchema);
        case "boolean":
          return decorateGeminiSchema({ type: Type.BOOLEAN }, zodSchema);
        default:
          // null / undefined / bigint / symbol literals have no direct
          // Gemini representation; use the generic placeholder.
          return decorateGeminiSchema(
            { type: Type.OBJECT, nullable: true },
            zodSchema,
          );
      }
    }
    default:
      // Unsupported Zod kinds degrade to a permissive placeholder.
      return decorateGeminiSchema(
        {
          type: Type.OBJECT,
          nullable: true,
        },
        zodSchema,
      );
  }
}
561+
562+
/**
 * Reads the type-name tag stored on a Zod schema's internal definition.
 *
 * @param schema - The Zod schema to inspect.
 * @returns The value of `schema._def.typeName`.
 */
export function getZodType(schema: z.ZodTypeAny): string {
  const { typeName } = schema._def;
  return typeName;
}

0 commit comments

Comments
 (0)