
Commit 08907eb

kamatharihanv and Arihan Varanasi authored
allow llmClient to be optionally passed in (#352) (#364)
* allow llmClient to be optionally passed in (#352)
* feat: allow llmClient to be optionally passed in
* update: add ollama client example from pr: #349
* update: README and changeset
* lint

Co-authored-by: Arihan Varanasi <[email protected]>
1 parent 89841fc commit 08907eb

8 files changed: +378 -5 lines

Diff for: .changeset/spicy-singers-flow.md

+5
@@ -0,0 +1,5 @@
---
"@browserbasehq/stagehand": minor
---

exposed llmClient in stagehand constructor

Diff for: README.md

+1
@@ -150,6 +150,7 @@ This constructor is used to create an instance of Stagehand.
  - `1`: SDK-level logging
  - `2`: LLM-client level logging (most granular)
- `debugDom`: a `boolean` that draws bounding boxes around elements presented to the LLM during automation.
- `llmClient`: (optional) a custom `LLMClient` implementation.

- **Returns:**
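Note: the new option slots straight into the existing constructor config. A minimal sketch (it mirrors examples/external_client.ts below, and the import paths assume the same repo layout):

// Sketch only: any LLMClient subclass can be supplied as llmClient.
import { Stagehand, type LogLine } from "../lib";
import { OllamaClient } from "./external_clients/ollama";

const stagehand = new Stagehand({
  env: "BROWSERBASE",
  llmClient: new OllamaClient(
    (message: LogLine) => console.log(message.message),
    false,      // enableCaching
    undefined,  // cache (LLMCache instance or undefined)
    "llama3.2", // modelName
  ),
});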

Diff for: examples/external_client.ts

+48
@@ -0,0 +1,48 @@
import { type ConstructorParams, type LogLine, Stagehand } from "../lib";
import { z } from "zod";
import { OllamaClient } from "./external_clients/ollama";

const StagehandConfig: ConstructorParams = {
  env: "BROWSERBASE",
  apiKey: process.env.BROWSERBASE_API_KEY,
  projectId: process.env.BROWSERBASE_PROJECT_ID,
  verbose: 1,
  llmClient: new OllamaClient(
    (message: LogLine) =>
      console.log(`[stagehand::${message.category}] ${message.message}`),
    false,
    undefined,
    "llama3.2",
  ),
  debugDom: true,
};

async function example() {
  const stagehand = new Stagehand(StagehandConfig);

  await stagehand.init();
  await stagehand.page.goto("https://news.ycombinator.com");

  const headlines = await stagehand.page.extract({
    instruction: "Extract only 3 stories from the Hacker News homepage.",
    schema: z.object({
      stories: z
        .array(
          z.object({
            title: z.string(),
            url: z.string(),
            points: z.number(),
          }),
        )
        .length(3),
    }),
  });

  console.log(headlines);

  await stagehand.close();
}

(async () => {
  await example();
})();
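Note: this example drives a Browserbase session, so it needs BROWSERBASE_API_KEY and BROWSERBASE_PROJECT_ID in the environment. To exercise only the Ollama client, a local-browser variant is a small change. A sketch, reusing the same imports and assuming Stagehand's `env: "LOCAL"` mode plus an Ollama server on its default port with the llama3.2 model available:

// Sketch only: same config as above, but without Browserbase credentials.
const LocalStagehandConfig: ConstructorParams = {
  env: "LOCAL", // assumes the local-browser mode of the Stagehand constructor
  verbose: 1,
  debugDom: true,
  llmClient: new OllamaClient(
    (message: LogLine) =>
      console.log(`[stagehand::${message.category}] ${message.message}`),
    false,      // enableCaching
    undefined,  // cache
    "llama3.2",
  ),
};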

Diff for: examples/external_clients/ollama.ts

+313
@@ -0,0 +1,313 @@
import OpenAI, { type ClientOptions } from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import type { LLMCache } from "../../lib/cache/LLMCache";
import { validateZodSchema } from "../../lib/utils";
import {
  type ChatCompletionOptions,
  type ChatMessage,
  LLMClient,
} from "../../lib/llm/LLMClient";
import type { LogLine } from "../../types/log";
import type { AvailableModel } from "../../types/model";
import type {
  ChatCompletion,
  ChatCompletionAssistantMessageParam,
  ChatCompletionContentPartImage,
  ChatCompletionContentPartText,
  ChatCompletionCreateParamsNonStreaming,
  ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionUserMessageParam,
} from "openai/resources/chat";

export class OllamaClient extends LLMClient {
  public type = "ollama" as const;
  private client: OpenAI;
  private cache: LLMCache | undefined;
  public logger: (message: LogLine) => void;
  private enableCaching: boolean;
  public clientOptions: ClientOptions;

  constructor(
    logger: (message: LogLine) => void,
    enableCaching = false,
    cache: LLMCache | undefined,
    modelName: "llama3.2",
    clientOptions?: ClientOptions,
  ) {
    super(modelName as AvailableModel);
    this.client = new OpenAI({
      ...clientOptions,
      baseURL: clientOptions?.baseURL || "http://localhost:11434/v1",
      apiKey: "ollama",
    });
    this.logger = logger;
    this.cache = cache;
    this.enableCaching = enableCaching;
    this.modelName = modelName as AvailableModel;
  }

  async createChatCompletion<T = ChatCompletion>(
    options: ChatCompletionOptions,
    retries = 3,
  ): Promise<T> {
    const { image, requestId, ...optionsWithoutImageAndRequestId } = options;

    // TODO: Implement vision support
    if (image) {
      throw new Error(
        "Image provided. Vision is not currently supported for Ollama",
      );
    }

    this.logger({
      category: "ollama",
      message: "creating chat completion",
      level: 1,
      auxiliary: {
        options: {
          value: JSON.stringify({
            ...optionsWithoutImageAndRequestId,
            requestId,
          }),
          type: "object",
        },
        modelName: {
          value: this.modelName,
          type: "string",
        },
      },
    });

    const cacheOptions = {
      model: this.modelName,
      messages: options.messages,
      temperature: options.temperature,
      top_p: options.top_p,
      frequency_penalty: options.frequency_penalty,
      presence_penalty: options.presence_penalty,
      image: image,
      response_model: options.response_model,
    };

    if (options.image) {
      const screenshotMessage: ChatMessage = {
        role: "user",
        content: [
          {
            type: "image_url",
            image_url: {
              url: `data:image/jpeg;base64,${options.image.buffer.toString("base64")}`,
            },
          },
          ...(options.image.description
            ? [{ type: "text", text: options.image.description }]
            : []),
        ],
      };

      options.messages.push(screenshotMessage);
    }

    if (this.enableCaching && this.cache) {
      const cachedResponse = await this.cache.get<T>(
        cacheOptions,
        options.requestId,
      );

      if (cachedResponse) {
        this.logger({
          category: "llm_cache",
          message: "LLM cache hit - returning cached response",
          level: 1,
          auxiliary: {
            requestId: {
              value: options.requestId,
              type: "string",
            },
            cachedResponse: {
              value: JSON.stringify(cachedResponse),
              type: "object",
            },
          },
        });
        return cachedResponse;
      }

      this.logger({
        category: "llm_cache",
        message: "LLM cache miss - no cached response found",
        level: 1,
        auxiliary: {
          requestId: {
            value: options.requestId,
            type: "string",
          },
        },
      });
    }

    let responseFormat = undefined;
    if (options.response_model) {
      responseFormat = zodResponseFormat(
        options.response_model.schema,
        options.response_model.name,
      );
    }

    /* eslint-disable */
    // Remove unsupported options
    const { response_model, ...ollamaOptions } = {
      ...optionsWithoutImageAndRequestId,
      model: this.modelName,
    };

    this.logger({
      category: "ollama",
      message: "creating chat completion",
      level: 1,
      auxiliary: {
        ollamaOptions: {
          value: JSON.stringify(ollamaOptions),
          type: "object",
        },
      },
    });

    const formattedMessages: ChatCompletionMessageParam[] =
      options.messages.map((message) => {
        if (Array.isArray(message.content)) {
          const contentParts = message.content.map((content) => {
            if ("image_url" in content) {
              const imageContent: ChatCompletionContentPartImage = {
                image_url: {
                  url: content.image_url.url,
                },
                type: "image_url",
              };
              return imageContent;
            } else {
              const textContent: ChatCompletionContentPartText = {
                text: content.text,
                type: "text",
              };
              return textContent;
            }
          });

          if (message.role === "system") {
            const formattedMessage: ChatCompletionSystemMessageParam = {
              ...message,
              role: "system",
              content: contentParts.filter(
                (content): content is ChatCompletionContentPartText =>
                  content.type === "text",
              ),
            };
            return formattedMessage;
          } else if (message.role === "user") {
            const formattedMessage: ChatCompletionUserMessageParam = {
              ...message,
              role: "user",
              content: contentParts,
            };
            return formattedMessage;
          } else {
            const formattedMessage: ChatCompletionAssistantMessageParam = {
              ...message,
              role: "assistant",
              content: contentParts.filter(
                (content): content is ChatCompletionContentPartText =>
                  content.type === "text",
              ),
            };
            return formattedMessage;
          }
        }

        const formattedMessage: ChatCompletionUserMessageParam = {
          role: "user",
          content: message.content,
        };

        return formattedMessage;
      });

    const body: ChatCompletionCreateParamsNonStreaming = {
      ...ollamaOptions,
      model: this.modelName,
      messages: formattedMessages,
      response_format: responseFormat,
      stream: false,
      tools: options.tools?.filter((tool) => "function" in tool), // ensure only OpenAI-compatible tools are used
    };

    const response = await this.client.chat.completions.create(body);

    this.logger({
      category: "ollama",
      message: "response",
      level: 1,
      auxiliary: {
        response: {
          value: JSON.stringify(response),
          type: "object",
        },
        requestId: {
          value: requestId,
          type: "string",
        },
      },
    });

    if (options.response_model) {
      const extractedData = response.choices[0].message.content;
      const parsedData = JSON.parse(extractedData);

      if (!validateZodSchema(options.response_model.schema, parsedData)) {
        if (retries > 0) {
          return this.createChatCompletion(options, retries - 1);
        }

        throw new Error("Invalid response schema");
      }

      if (this.enableCaching) {
        this.cache.set(
          cacheOptions,
          {
            ...parsedData,
          },
          options.requestId,
        );
      }

      return parsedData;
    }

    if (this.enableCaching) {
      this.logger({
        category: "llm_cache",
        message: "caching response",
        level: 1,
        auxiliary: {
          requestId: {
            value: options.requestId,
            type: "string",
          },
          cacheOptions: {
            value: JSON.stringify(cacheOptions),
            type: "object",
          },
          response: {
            value: JSON.stringify(response),
            type: "object",
          },
        },
      });
      this.cache.set(cacheOptions, response, options.requestId);
    }

    return response as T;
  }
}
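Note: the constructor only falls back to http://localhost:11434/v1 when no baseURL is supplied, and it hard-codes the apiKey to "ollama", so the same class can target a non-default Ollama endpoint through clientOptions. A sketch, reusing the imports from examples/external_client.ts and using an illustrative host name:

// Sketch only: point OllamaClient at a remote OpenAI-compatible Ollama endpoint.
const remoteOllama = new OllamaClient(
  (message: LogLine) => console.log(message.message),
  false,      // enableCaching
  undefined,  // cache
  "llama3.2",
  { baseURL: "http://ollama.internal:11434/v1" }, // placeholder host
);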

0 commit comments
