Skip to content

fix thinking for gemini models #1113

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 35 additions & 28 deletions src/providers/google-vertex-ai/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ export const VertexAnthropicChatCompleteConfig: ProviderConfig = {
model: {
param: 'model',
required: false,
transform: (params: Params) => {
transform: () => {
return undefined;
},
},
Expand Down Expand Up @@ -438,6 +438,7 @@ export const GoogleChatCompleteResponseTransform: (
promptTokenCount = 0,
candidatesTokenCount = 0,
totalTokenCount = 0,
thoughtsTokenCount = 0,
} = response.usageMetadata;

return {
Expand All @@ -449,8 +450,9 @@ export const GoogleChatCompleteResponseTransform: (
choices:
response.candidates?.map((generation, index) => {
// transform tool calls and content by iterating over the content parts
let toolCalls: ToolCall[] = [];
const toolCalls: ToolCall[] = [];
let content: string | undefined;
const contentBlocks = [];
for (const part of generation.content?.parts ?? []) {
if (part.functionCall) {
toolCalls.push({
Expand All @@ -462,12 +464,11 @@ export const GoogleChatCompleteResponseTransform: (
},
});
} else if (part.text) {
// if content is already set to the chain of thought message and the user requires both the CoT message and the completion, we need to append the completion to the CoT message
if (content?.length && !strictOpenAiCompliance) {
content += '\r\n\r\n' + part.text;
if (part.thought) {
contentBlocks.push({ type: 'thinking', thinking: part.text });
} else {
// if content is already set to CoT, but user requires only the completion, we need to set content to the completion
content = part.text;
contentBlocks.push({ type: 'text', text: part.text });
}
}
}
Expand All @@ -476,6 +477,8 @@ export const GoogleChatCompleteResponseTransform: (
role: MESSAGE_ROLES.ASSISTANT,
...(toolCalls.length && { tool_calls: toolCalls }),
...(content && { content }),
...(!strictOpenAiCompliance &&
contentBlocks.length && { content_blocks: contentBlocks }),
};
const logprobsContent: Logprobs[] | null =
transformVertexLogprobs(generation);
Expand Down Expand Up @@ -503,6 +506,9 @@ export const GoogleChatCompleteResponseTransform: (
prompt_tokens: promptTokenCount,
completion_tokens: candidatesTokenCount,
total_tokens: totalTokenCount,
completion_tokens_details: {
reasoning_tokens: thoughtsTokenCount,
},
},
};
}
Expand Down Expand Up @@ -593,6 +599,9 @@ export const GoogleChatCompleteStreamChunkTransform: (
prompt_tokens: parsedChunk.usageMetadata.promptTokenCount,
completion_tokens: parsedChunk.usageMetadata.candidatesTokenCount,
total_tokens: parsedChunk.usageMetadata.totalTokenCount,
completion_tokens_details: {
reasoning_tokens: parsedChunk.usageMetadata.thoughtsTokenCount ?? 0,
},
};
}

Expand All @@ -604,32 +613,30 @@ export const GoogleChatCompleteStreamChunkTransform: (
provider: GOOGLE_VERTEX_AI,
choices:
parsedChunk.candidates?.map((generation, index) => {
let message: Message = { role: 'assistant', content: '' };
let message: any = { role: 'assistant', content: '' };
if (generation.content?.parts[0]?.text) {
if (generation.content.parts[0].thought)
streamState.containsChainOfThoughtMessage = true;

let content: string =
strictOpenAiCompliance && streamState.containsChainOfThoughtMessage
? ''
: generation.content.parts[0]?.text;
if (generation.content.parts[1]?.text) {
if (strictOpenAiCompliance)
content = generation.content.parts[1].text;
else content += '\r\n\r\n' + generation.content.parts[1]?.text;
streamState.containsChainOfThoughtMessage = false;
} else if (
streamState.containsChainOfThoughtMessage &&
!generation.content.parts[0]?.thought
) {
if (strictOpenAiCompliance)
content = generation.content.parts[0].text;
else content = '\r\n\r\n' + content;
streamState.containsChainOfThoughtMessage = false;
const contentBlocks = [];
let content = '';
for (const part of generation.content.parts) {
if (part.thought) {
contentBlocks.push({
index: 0,
delta: { thinking: part.text },
});
streamState.containsChainOfThoughtMessage = true;
} else {
content = part.text ?? '';
contentBlocks.push({
index: streamState.containsChainOfThoughtMessage ? 1 : 0,
delta: { text: part.text },
});
}
}
message = {
role: 'assistant',
content,
...(!strictOpenAiCompliance &&
contentBlocks.length && { content_blocks: contentBlocks }),
};
} else if (generation.content?.parts[0]?.functionCall) {
message = {
Expand Down Expand Up @@ -706,7 +713,7 @@ export const VertexAnthropicChatCompleteResponseTransform: (
}

if ('content' in response) {
const { input_tokens = 0, output_tokens = 0 } = response?.usage;
const { input_tokens = 0, output_tokens = 0 } = response?.usage ?? {};

let content: AnthropicContentItem[] | string = strictOpenAiCompliance
? ''
Expand Down
6 changes: 4 additions & 2 deletions src/providers/google-vertex-ai/transformGenerationConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,11 @@ export function transformGenerationConfig(params: Params) {
}

if (params?.thinking) {
const { budget_tokens, type } = params.thinking;
const thinkingConfig: Record<string, any> = {};
thinkingConfig['include_thoughts'] = true;
thinkingConfig['thinking_budget'] = params.thinking.budget_tokens;
thinkingConfig['include_thoughts'] =
type === 'enabled' && budget_tokens ? true : false;
thinkingConfig['thinking_budget'] = budget_tokens;
generationConfig['thinking_config'] = thinkingConfig;
}

Expand Down
1 change: 1 addition & 0 deletions src/providers/google-vertex-ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ export interface GoogleGenerateContentResponse {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
thoughtsTokenCount?: number;
};
}

Expand Down
66 changes: 37 additions & 29 deletions src/providers/google/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ const transformGenerationConfig = (params: Params) => {
}
if (params?.thinking) {
const thinkingConfig: Record<string, any> = {};
thinkingConfig['include_thoughts'] = true;
const { budget_tokens, type } = params.thinking;
thinkingConfig['include_thoughts'] =
type === 'enabled' && budget_tokens ? true : false;
thinkingConfig['thinking_budget'] = params.thinking.budget_tokens;
generationConfig['thinking_config'] = thinkingConfig;
}
Expand Down Expand Up @@ -486,6 +488,7 @@ interface GoogleGenerateContentResponse {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
thoughtsTokenCount?: number;
};
}

Expand Down Expand Up @@ -536,8 +539,9 @@ export const GoogleChatCompleteResponseTransform: (
choices:
response.candidates?.map((generation, idx) => {
// transform tool calls and content by iterating over the content parts
let toolCalls: ToolCall[] = [];
const toolCalls: ToolCall[] = [];
let content: string | undefined;
const contentBlocks = [];
for (const part of generation.content?.parts ?? []) {
if (part.functionCall) {
toolCalls.push({
Expand All @@ -549,12 +553,11 @@ export const GoogleChatCompleteResponseTransform: (
},
});
} else if (part.text) {
// if content is already set to the chain of thought message and the user requires both the CoT message and the completion, we need to append the completion to the CoT message
if (content?.length && !strictOpenAiCompliance) {
content += '\r\n\r\n' + part.text;
if (part.thought) {
contentBlocks.push({ type: 'thinking', thinking: part.text });
} else {
// if content is already set to CoT, but user requires only the completion, we need to set content to the completion
content = part.text;
contentBlocks.push({ type: 'text', text: part.text });
}
}
}
Expand All @@ -563,6 +566,8 @@ export const GoogleChatCompleteResponseTransform: (
role: MESSAGE_ROLES.ASSISTANT,
...(toolCalls.length && { tool_calls: toolCalls }),
...(content && { content }),
...(!strictOpenAiCompliance &&
contentBlocks.length && { content_blocks: contentBlocks }),
};
const logprobsContent: Logprobs[] | null =
transformVertexLogprobs(generation);
Expand All @@ -586,6 +591,9 @@ export const GoogleChatCompleteResponseTransform: (
prompt_tokens: response.usageMetadata.promptTokenCount,
completion_tokens: response.usageMetadata.candidatesTokenCount,
total_tokens: response.usageMetadata.totalTokenCount,
completion_tokens_details: {
reasoning_tokens: response.usageMetadata.thoughtsTokenCount ?? 0,
},
},
};
}
Expand Down Expand Up @@ -631,6 +639,9 @@ export const GoogleChatCompleteStreamChunkTransform: (
prompt_tokens: parsedChunk.usageMetadata.promptTokenCount,
completion_tokens: parsedChunk.usageMetadata.candidatesTokenCount,
total_tokens: parsedChunk.usageMetadata.totalTokenCount,
completion_tokens_details: {
reasoning_tokens: parsedChunk.usageMetadata.thoughtsTokenCount ?? 0,
},
};
}

Expand All @@ -643,35 +654,32 @@ export const GoogleChatCompleteStreamChunkTransform: (
provider: 'google',
choices:
parsedChunk.candidates?.map((generation, index) => {
let message: Message = { role: 'assistant', content: '' };
let message: any = { role: 'assistant', content: '' };
if (generation.content?.parts[0]?.text) {
if (generation.content.parts[0].thought)
streamState.containsChainOfThoughtMessage = true;

let content: string =
strictOpenAiCompliance &&
streamState.containsChainOfThoughtMessage
? ''
: generation.content.parts[0]?.text;
if (generation.content.parts[1]?.text) {
if (strictOpenAiCompliance)
content = generation.content.parts[1].text;
else content += '\r\n\r\n' + generation.content.parts[1]?.text;
streamState.containsChainOfThoughtMessage = false;
} else if (
streamState.containsChainOfThoughtMessage &&
!generation.content.parts[0]?.thought
) {
if (strictOpenAiCompliance)
content = generation.content.parts[0].text;
else content = '\r\n\r\n' + content;
streamState.containsChainOfThoughtMessage = false;
const contentBlocks = [];
let content = '';
for (const part of generation.content.parts) {
if (part.thought) {
contentBlocks.push({
index: 0,
delta: { thinking: part.text },
});
streamState.containsChainOfThoughtMessage = true;
} else {
content = part.text ?? '';
contentBlocks.push({
index: streamState.containsChainOfThoughtMessage ? 1 : 0,
delta: { text: part.text },
});
}
}
message = {
role: 'assistant',
content,
...(!strictOpenAiCompliance &&
contentBlocks.length && { content_blocks: contentBlocks }),
};
} else if (generation.content.parts[0]?.functionCall) {
} else if (generation.content?.parts[0]?.functionCall) {
message = {
role: 'assistant',
tool_calls: generation.content.parts.map((part, idx) => {
Expand Down