
Commit d38f247

make ai sdk native

1 parent 5c6d2cf  commit d38f247

10 files changed: +377 -205 lines

evals/index.eval.ts  (+1 -1)

@@ -34,12 +34,12 @@ import { StagehandEvalError } from "@/types/stagehandErrors";
 import { CustomOpenAIClient } from "@/examples/external_clients/customOpenAI";
 import OpenAI from "openai";
 import { initStagehand } from "./initStagehand";
-import { AISdkClient } from "@/examples/external_clients/aisdk";
 import { google } from "@ai-sdk/google";
 import { anthropic } from "@ai-sdk/anthropic";
 import { groq } from "@ai-sdk/groq";
 import { cerebras } from "@ai-sdk/cerebras";
 import { openai } from "@ai-sdk/openai";
+import { AISdkClient } from "@/lib/llm/aisdk";

 dotenv.config();

 /**

examples/ai_sdk_example.ts  (+2 -5)

@@ -1,15 +1,12 @@
-import { openai } from "@ai-sdk/openai";
 import { Stagehand } from "@/dist";
-import { AISdkClient } from "./external_clients/aisdk";
 import StagehandConfig from "@/stagehand.config";
+import { openai } from "@ai-sdk/openai";
 import { z } from "zod";

 async function example() {
   const stagehand = new Stagehand({
     ...StagehandConfig,
-    llmClient: new AISdkClient({
-      model: openai("gpt-4o"),
-    }),
+    modelName: openai("gpt-4o"),
   });

   await stagehand.init();
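With this commit, an AI SDK model object can be handed to Stagehand directly as modelName, so the example no longer builds an AISdkClient by hand. Below is a minimal sketch of the updated usage with a different AI SDK provider; the Anthropic model ID and the page.extract call are illustrative and not part of this diff.

import { Stagehand } from "@/dist";
import StagehandConfig from "@/stagehand.config";
import { anthropic } from "@ai-sdk/anthropic";
import { z } from "zod";

async function sketch() {
  // The AI SDK model object is passed straight through as modelName.
  const stagehand = new Stagehand({
    ...StagehandConfig,
    modelName: anthropic("claude-3-5-sonnet-latest"), // illustrative model ID
  });
  await stagehand.init();

  await stagehand.page.goto("https://example.com");
  const { heading } = await stagehand.page.extract({
    instruction: "extract the main heading of the page",
    schema: z.object({ heading: z.string() }),
  });
  console.log(heading);

  await stagehand.close();
}

sketch();

Plain model strings such as "gpt-4.1" keep working alongside this; they still resolve through the existing provider map, as the LLMProvider.ts changes below show.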

examples/external_clients/aisdk.ts  (-122)

This file was deleted.

lib/index.ts  (+9 -3)

@@ -45,6 +45,7 @@ import {
   MissingEnvironmentVariableError,
   UnsupportedModelError,
 } from "../types/stagehandErrors";
+import { LanguageModel } from "ai";

 dotenv.config({ path: ".env" });

@@ -384,7 +385,7 @@ export class Stagehand {
   public llmClient: LLMClient;
   public readonly userProvidedInstructions?: string;
   private usingAPI: boolean;
-  private modelName: AvailableModel;
+  private modelName: AvailableModel | LanguageModel;
   public apiClient: StagehandAPI | undefined;
   public readonly waitForCaptchaSolves: boolean;
   private localBrowserLaunchOptions?: LocalBrowserLaunchOptions;

@@ -656,18 +657,23 @@ export class Stagehand {
       projectId: this.projectId,
       logger: this.logger,
     });
+
     const modelApiKey =
+      // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel
       LLMProvider.getModelProvider(this.modelName) === "openai"
         ? process.env.OPENAI_API_KEY || this.llmClient.clientOptions.apiKey
-        : LLMProvider.getModelProvider(this.modelName) === "anthropic"
+        : // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel
+          LLMProvider.getModelProvider(this.modelName) === "anthropic"
           ? process.env.ANTHROPIC_API_KEY ||
             this.llmClient.clientOptions.apiKey
-          : LLMProvider.getModelProvider(this.modelName) === "google"
+          : // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel
+            LLMProvider.getModelProvider(this.modelName) === "google"
            ? process.env.GOOGLE_API_KEY ||
              this.llmClient.clientOptions.apiKey
            : undefined;

     const { sessionId } = await this.apiClient.init({
+      // @ts-expect-error - this is a temporary fix to allow the modelName to be a LanguageModel
       modelName: this.modelName,
       modelApiKey: modelApiKey,
       domSettleTimeoutMs: this.domSettleTimeoutMs,
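The repeated @ts-expect-error suppressions exist because LLMProvider.getModelProvider still accepts only an AvailableModel string, while this.modelName may now also be an AI SDK LanguageModel. The sketch below shows the narrowing that the new union implies; isAvailableModel is a hypothetical helper and not part of this commit, which instead keeps the suppressions as a stated temporary fix.

import { LanguageModel } from "ai";
import { AvailableModel } from "../types/model";

// Hypothetical type guard mirroring the typeof check added in LLMProvider.ts:
// string values are looked up in modelToProviderMap, objects are AI SDK models.
function isAvailableModel(
  model: AvailableModel | LanguageModel,
): model is AvailableModel {
  return typeof model === "string";
}

With a guard like this, the key-resolution ternary could branch once on the kind of model instead of suppressing the type error at each comparison.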

lib/llm/LLMProvider.ts  (+36 -11)

@@ -1,20 +1,35 @@
+import {
+  UnsupportedModelError,
+  UnsupportedModelProviderError,
+} from "@/types/stagehandErrors";
+import { LanguageModel } from "ai";
 import { LogLine } from "../../types/log";
 import {
   AvailableModel,
   ClientOptions,
   ModelProvider,
 } from "../../types/model";
 import { LLMCache } from "../cache/LLMCache";
+import { AISdkClient } from "./aisdk";
 import { AnthropicClient } from "./AnthropicClient";
 import { CerebrasClient } from "./CerebrasClient";
 import { GoogleClient } from "./GoogleClient";
 import { GroqClient } from "./GroqClient";
 import { LLMClient } from "./LLMClient";
 import { OpenAIClient } from "./OpenAIClient";
-import {
-  UnsupportedModelError,
-  UnsupportedModelProviderError,
-} from "@/types/stagehandErrors";
+
+function modelToProvider(
+  modelName: AvailableModel | LanguageModel,
+): ModelProvider {
+  if (typeof modelName === "string") {
+    const provider = modelToProviderMap[modelName];
+    if (!provider) {
+      throw new UnsupportedModelError(Object.keys(modelToProviderMap));
+    }
+    return provider;
+  }
+  return "aisdk";
+}

 const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
   "gpt-4.1": "openai",

@@ -81,53 +96,63 @@ export class LLMProvider {
   }

   getClient(
-    modelName: AvailableModel,
+    modelName: AvailableModel | LanguageModel,
     clientOptions?: ClientOptions,
   ): LLMClient {
-    const provider = modelToProviderMap[modelName];
+    const provider = modelToProvider(modelName);
     if (!provider) {
       throw new UnsupportedModelError(Object.keys(modelToProviderMap));
     }

+    if (provider === "aisdk") {
+      return new AISdkClient({
+        model: modelName as LanguageModel,
+        logger: this.logger,
+        enableCaching: this.enableCaching,
+        cache: this.cache,
+      });
+    }
+
+    const availableModel = modelName as AvailableModel;
     switch (provider) {
       case "openai":
         return new OpenAIClient({
           logger: this.logger,
           enableCaching: this.enableCaching,
           cache: this.cache,
-          modelName,
+          modelName: availableModel,
           clientOptions,
         });
       case "anthropic":
         return new AnthropicClient({
           logger: this.logger,
           enableCaching: this.enableCaching,
           cache: this.cache,
-          modelName,
+          modelName: availableModel,
           clientOptions,
         });
       case "cerebras":
         return new CerebrasClient({
           logger: this.logger,
           enableCaching: this.enableCaching,
           cache: this.cache,
-          modelName,
+          modelName: availableModel,
           clientOptions,
         });
       case "groq":
         return new GroqClient({
           logger: this.logger,
           enableCaching: this.enableCaching,
           cache: this.cache,
-          modelName,
+          modelName: availableModel,
           clientOptions,
         });
       case "google":
         return new GoogleClient({
           logger: this.logger,
           enableCaching: this.enableCaching,
           cache: this.cache,
-          modelName,
+          modelName: availableModel,
           clientOptions,
         });
       default:
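getClient now dispatches on the shape of modelName: string identifiers resolve through modelToProviderMap as before, while any AI SDK LanguageModel object maps to the "aisdk" provider and returns the new built-in AISdkClient. A rough usage sketch follows; the LLMProvider constructor arguments are assumed from the fields used above and the import paths are illustrative, not taken from this diff.

import { openai } from "@ai-sdk/openai";
import { LLMProvider } from "@/lib/llm/LLMProvider";

// Assumed constructor shape: (logger, enableCaching).
const provider = new LLMProvider(() => {}, false);

// A plain model string still resolves through modelToProviderMap (here: OpenAIClient).
const namedClient = provider.getClient("gpt-4.1");

// An AI SDK LanguageModel object resolves to "aisdk" and returns an AISdkClient.
const sdkClient = provider.getClient(openai("gpt-4o"));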
