|
| 1 | +import { |
| 2 | + UnsupportedModelError, |
| 3 | + UnsupportedModelProviderError, |
| 4 | +} from "@/types/stagehandErrors"; |
| 5 | +import { LanguageModel } from "ai"; |
1 | 6 | import { LogLine } from "../../types/log";
|
2 | 7 | import {
|
3 | 8 | AvailableModel,
|
4 | 9 | ClientOptions,
|
5 | 10 | ModelProvider,
|
6 | 11 | } from "../../types/model";
|
7 | 12 | import { LLMCache } from "../cache/LLMCache";
|
| 13 | +import { AISdkClient } from "./aisdk"; |
8 | 14 | import { AnthropicClient } from "./AnthropicClient";
|
9 | 15 | import { CerebrasClient } from "./CerebrasClient";
|
10 | 16 | import { GoogleClient } from "./GoogleClient";
|
11 | 17 | import { GroqClient } from "./GroqClient";
|
12 | 18 | import { LLMClient } from "./LLMClient";
|
13 | 19 | import { OpenAIClient } from "./OpenAIClient";
|
14 |
| -import { |
15 |
| - UnsupportedModelError, |
16 |
| - UnsupportedModelProviderError, |
17 |
| -} from "@/types/stagehandErrors"; |
| 20 | + |
| 21 | +function modelToProvider( |
| 22 | + modelName: AvailableModel | LanguageModel, |
| 23 | +): ModelProvider { |
| 24 | + if (typeof modelName === "string") { |
| 25 | + const provider = modelToProviderMap[modelName]; |
| 26 | + if (!provider) { |
| 27 | + throw new UnsupportedModelError(Object.keys(modelToProviderMap)); |
| 28 | + } |
| 29 | + return provider; |
| 30 | + } |
| 31 | + return "aisdk"; |
| 32 | +} |
18 | 33 |
|
19 | 34 | const modelToProviderMap: { [key in AvailableModel]: ModelProvider } = {
|
20 | 35 | "gpt-4.1": "openai",
|
@@ -81,53 +96,63 @@ export class LLMProvider {
|
81 | 96 | }
|
82 | 97 |
|
83 | 98 | getClient(
|
84 |
| - modelName: AvailableModel, |
| 99 | + modelName: AvailableModel | LanguageModel, |
85 | 100 | clientOptions?: ClientOptions,
|
86 | 101 | ): LLMClient {
|
87 |
| - const provider = modelToProviderMap[modelName]; |
| 102 | + const provider = modelToProvider(modelName); |
88 | 103 | if (!provider) {
|
89 | 104 | throw new UnsupportedModelError(Object.keys(modelToProviderMap));
|
90 | 105 | }
|
91 | 106 |
|
| 107 | + if (provider === "aisdk") { |
| 108 | + return new AISdkClient({ |
| 109 | + model: modelName as LanguageModel, |
| 110 | + logger: this.logger, |
| 111 | + enableCaching: this.enableCaching, |
| 112 | + cache: this.cache, |
| 113 | + }); |
| 114 | + } |
| 115 | + |
| 116 | + const availableModel = modelName as AvailableModel; |
92 | 117 | switch (provider) {
|
93 | 118 | case "openai":
|
94 | 119 | return new OpenAIClient({
|
95 | 120 | logger: this.logger,
|
96 | 121 | enableCaching: this.enableCaching,
|
97 | 122 | cache: this.cache,
|
98 |
| - modelName, |
| 123 | + modelName: availableModel, |
99 | 124 | clientOptions,
|
100 | 125 | });
|
101 | 126 | case "anthropic":
|
102 | 127 | return new AnthropicClient({
|
103 | 128 | logger: this.logger,
|
104 | 129 | enableCaching: this.enableCaching,
|
105 | 130 | cache: this.cache,
|
106 |
| - modelName, |
| 131 | + modelName: availableModel, |
107 | 132 | clientOptions,
|
108 | 133 | });
|
109 | 134 | case "cerebras":
|
110 | 135 | return new CerebrasClient({
|
111 | 136 | logger: this.logger,
|
112 | 137 | enableCaching: this.enableCaching,
|
113 | 138 | cache: this.cache,
|
114 |
| - modelName, |
| 139 | + modelName: availableModel, |
115 | 140 | clientOptions,
|
116 | 141 | });
|
117 | 142 | case "groq":
|
118 | 143 | return new GroqClient({
|
119 | 144 | logger: this.logger,
|
120 | 145 | enableCaching: this.enableCaching,
|
121 | 146 | cache: this.cache,
|
122 |
| - modelName, |
| 147 | + modelName: availableModel, |
123 | 148 | clientOptions,
|
124 | 149 | });
|
125 | 150 | case "google":
|
126 | 151 | return new GoogleClient({
|
127 | 152 | logger: this.logger,
|
128 | 153 | enableCaching: this.enableCaching,
|
129 | 154 | cache: this.cache,
|
130 |
| - modelName, |
| 155 | + modelName: availableModel, |
131 | 156 | clientOptions,
|
132 | 157 | });
|
133 | 158 | default:
|
|
0 commit comments