Skip to content

Commit c10bf0f

Browse files
authored
Add support for Workers AI in local mode (#4522)
* Add support for Workers AI in local mode * trigger build * Add support for AI binding in pages * Lint code * Fix pages binding
1 parent 5e67ea1 commit c10bf0f

File tree

16 files changed

+170
-8
lines changed

16 files changed

+170
-8
lines changed

.changeset/mean-jars-count.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"wrangler": minor
3+
---
4+
5+
Add support for Workers AI in local mode

fixtures/ai-app/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
dist

fixtures/ai-app/package.json

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"name": "ai-app",
3+
"version": "1.0.1",
4+
"private": true,
5+
"description": "",
6+
"license": "ISC",
7+
"author": "",
8+
"main": "src/index.js",
9+
"scripts": {
10+
"check:type": "tsc",
11+
"test": "vitest run",
12+
"test:watch": "vitest",
13+
"type:tests": "tsc -p ./tests/tsconfig.json"
14+
},
15+
"devDependencies": {
16+
"undici": "^5.23.0",
17+
"wrangler": "workspace:*",
18+
"@cloudflare/workers-tsconfig": "workspace:^",
19+
"@cloudflare/ai": "^1.0.35"
20+
}
21+
}

fixtures/ai-app/src/index.js

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
console.log("startup log");
2+
3+
export default {
4+
async fetch(request, env) {
5+
console.log("request log");
6+
7+
return Response.json({
8+
binding: env.AI,
9+
fetcher: env.AI.fetch.toString(),
10+
});
11+
},
12+
};

fixtures/ai-app/tests/index.test.ts

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import { resolve } from "path";
2+
import { fetch } from "undici";
3+
import { describe, it, beforeAll, afterAll } from "vitest";
4+
import { runWranglerDev } from "../../shared/src/run-wrangler-long-lived";
5+
6+
describe("'wrangler dev' correctly renders pages", () => {
7+
let ip: string,
8+
port: number,
9+
stop: (() => Promise<unknown>) | undefined,
10+
getOutput: () => string;
11+
12+
beforeAll(async () => {
13+
({ ip, port, stop, getOutput } = await runWranglerDev(
14+
resolve(__dirname, ".."),
15+
["--local", "--port=0", "--inspector-port=0"]
16+
));
17+
});
18+
19+
afterAll(async () => {
20+
await stop?.();
21+
});
22+
23+
it("ai binding is defined ", async ({ expect }) => {
24+
const response = await fetch(`http://${ip}:${port}/`);
25+
const content = await response.json();
26+
expect(content).toEqual({
27+
binding: {},
28+
fetcher: "function fetch() { [native code] }",
29+
});
30+
});
31+
});

fixtures/ai-app/tests/tsconfig.json

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"extends": "@cloudflare/workers-tsconfig/tsconfig.json",
3+
"compilerOptions": {
4+
"types": ["node"]
5+
},
6+
"include": ["**/*.ts", "../../../node-types.d.ts"]
7+
}

fixtures/ai-app/tsconfig.json

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"compilerOptions": {
3+
"target": "ES2020",
4+
"esModuleInterop": true,
5+
"module": "CommonJS",
6+
"lib": ["ES2020"],
7+
"types": ["node"],
8+
"skipLibCheck": true,
9+
"moduleResolution": "node",
10+
"noEmit": true
11+
},
12+
"include": ["tests", "../../node-types.d.ts"]
13+
}

fixtures/ai-app/vitest.config.ts

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import { defineConfig } from "vitest/config";
2+
3+
export default defineConfig({
4+
test: {
5+
testTimeout: 10_000,
6+
hookTimeout: 10_000,
7+
teardownTimeout: 10_000,
8+
useAtomics: true,
9+
},
10+
});

fixtures/ai-app/wrangler.toml

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
name = "ai-app"
2+
compatibility_date = "2023-11-21"
3+
4+
main = "src/index.js"
5+
6+
[ai]
7+
binding = "AI"

packages/wrangler/package.json

+1
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
"xxhash-wasm": "^1.0.1"
118118
},
119119
"devDependencies": {
120+
"@cloudflare/ai": "^1.0.35",
120121
"@cloudflare/cli": "workspace:*",
121122
"@cloudflare/eslint-config-worker": "*",
122123
"@cloudflare/pages-shared": "workspace:^",

packages/wrangler/src/ai/fetcher.ts

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import { Response } from "miniflare";
2+
import { performApiFetch } from "../cfetch/internal";
3+
import { getAccountId } from "../user";
4+
import type { Request } from "miniflare";
5+
6+
export async function AIFetcher(request: Request) {
7+
const accountId = await getAccountId();
8+
9+
request.headers.delete("Host");
10+
request.headers.delete("Content-Length");
11+
12+
const res = await performApiFetch(`/accounts/${accountId}/ai/run/proxy`, {
13+
method: "POST",
14+
headers: Object.fromEntries(request.headers.entries()),
15+
body: request.body,
16+
duplex: "half",
17+
});
18+
19+
return new Response(res.body, { status: res.status });
20+
}

packages/wrangler/src/api/dev.ts

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ export interface UnstableDevOptions {
5050
bucket_name: string;
5151
preview_bucket_name?: string;
5252
}[];
53+
ai?: {
54+
binding: string;
55+
};
5356
moduleRoot?: string;
5457
rules?: Rule[];
5558
logLevel?: "none" | "info" | "error" | "log" | "warn" | "debug"; // Specify logging level [choices: "debug", "info", "log", "warn", "error", "none"] [default: "log"]

packages/wrangler/src/dev.tsx

+5-1
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ export type AdditionalDevProps = {
326326
preview_bucket_name?: string;
327327
jurisdiction?: string;
328328
}[];
329+
ai?: {
330+
binding: string;
331+
};
329332
d1Databases?: Environment["d1_databases"];
330333
processEntrypoint?: boolean;
331334
additionalModules?: CfModule[];
@@ -832,6 +835,7 @@ function getBindingsAndAssetPaths(args: StartDevOptions, configParam: Config) {
832835
r2: args.r2,
833836
services: args.services,
834837
d1Databases: args.d1Databases,
838+
ai: args.ai,
835839
});
836840

837841
const maskedVars = maskVars(bindings, configParam);
@@ -893,7 +897,7 @@ function getBindings(
893897
wasm_modules: configParam.wasm_modules,
894898
text_blobs: configParam.text_blobs,
895899
browser: configParam.browser,
896-
ai: configParam.ai,
900+
ai: configParam.ai || args.ai,
897901
data_blobs: configParam.data_blobs,
898902
durable_objects: {
899903
bindings: [

packages/wrangler/src/dev/miniflare.ts

+6-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import assert from "node:assert";
22
import { realpathSync } from "node:fs";
33
import path from "node:path";
44
import { Log, LogLevel, TypedEventTarget, Mutex, Miniflare } from "miniflare";
5+
import { AIFetcher } from "../ai/fetcher";
56
import { ModuleTypeToRuleType } from "../deployment-bundle/module-collection";
67
import { withSourceURLs } from "../deployment-bundle/source-url";
78
import { getHttpsOptions } from "../https-options";
@@ -312,6 +313,10 @@ function buildBindingOptions(config: ConfigBundle) {
312313
.join("\n"),
313314
};
314315

316+
if (bindings.ai?.binding) {
317+
config.serviceBindings[bindings.ai.binding] = AIFetcher;
318+
}
319+
315320
const bindingOptions = {
316321
bindings: bindings.vars,
317322
textBlobBindings,
@@ -502,13 +507,7 @@ async function buildMiniflareOptions(
502507
logger.warn("Miniflare 3 does not support CRON triggers yet, ignoring...");
503508
}
504509

505-
if (config.bindings.ai) {
506-
logger.warn(
507-
"Workers AI is not currently supported in local mode. Please use --remote to work with it."
508-
);
509-
}
510-
511-
if (!config.bindings.ai && config.bindings.vectorize?.length) {
510+
if (config.bindings.vectorize?.length) {
512511
// TODO: add local support for Vectorize bindings (https://github.com/cloudflare/workers-sdk/issues/4360)
513512
logger.warn(
514513
"Vectorize bindings are not currently supported in local mode. Please use --remote if you are working with them."

packages/wrangler/src/pages/dev.ts

+6
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ export function Options(yargs: CommonYargsArgv) {
162162
type: "array",
163163
description: "R2 bucket to bind (--r2 R2_BINDING)",
164164
},
165+
ai: {
166+
type: "string",
167+
description: "AI to bind (--ai AI_BINDING)",
168+
},
165169
service: {
166170
type: "array",
167171
description: "Service to bind (--service SERVICE=SCRIPT_NAME)",
@@ -222,6 +226,7 @@ export const Handler = async ({
222226
do: durableObjects = [],
223227
d1: d1s = [],
224228
r2: r2s = [],
229+
ai,
225230
service: requestedServices = [],
226231
liveReload,
227232
localProtocol,
@@ -677,6 +682,7 @@ export const Handler = async ({
677682
return { binding, bucket_name: ref || binding.toString() };
678683
})
679684
.filter(Boolean) as AdditionalDevProps["r2"],
685+
ai: ai ? { binding: ai.toString() } : undefined,
680686
rules: usingWorkerDirectory
681687
? [
682688
{

pnpm-lock.yaml

+22
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)