Skip to content

Commit 5c4686f

Browse files
authored
[wasm] Implement MINT_SWITCH opcode in jiterpreter (#107423)
* Implement MINT_SWITCH opcode (without support for backward jumps)
* Introduce runtime option for max switch size (set to 0 to disable switches)
* Disable trace generation once the trace table fills up, since there's no point to it
1 parent a349912 commit 5c4686f

File tree

6 files changed

+194
-46
lines changed

6 files changed

+194
-46
lines changed

src/mono/browser/runtime/jiterpreter-enums.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ export const enum JiterpCounter {
2121
BackBranchesNotEmitted,
2222
ElapsedGenerationMs,
2323
ElapsedCompilationMs,
24+
SwitchTargetsOk,
25+
SwitchTargetsFailed,
2426
}
2527

2628
// keep in sync with jiterpreter.c, see mono_jiterp_get_member_offset
@@ -127,7 +129,8 @@ export const enum BailoutReason {
127129
Icall,
128130
UnexpectedRetIp,
129131
LeaveCheck,
130-
Switch,
132+
SwitchSize,
133+
SwitchTarget,
131134
}
132135

133136
export const BailoutReasonNames = [
@@ -158,7 +161,8 @@ export const BailoutReasonNames = [
158161
"Icall",
159162
"UnexpectedRetIp",
160163
"LeaveCheck",
161-
"Switch",
164+
"SwitchSize",
165+
"SwitchTarget",
162166
];
163167

164168
export const enum JitQueue {

src/mono/browser/runtime/jiterpreter-support.ts

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1150,7 +1150,14 @@ type CfgBranch = {
11501150
branchType: CfgBranchType;
11511151
}
11521152

1153-
type CfgSegment = CfgBlob | CfgBranchBlockHeader | CfgBranch;
1153+
// A CFG segment describing a switch-style jump table. When the Cfg emits this segment it
// becomes a wasm br_table whose entries are resolved against the Cfg's block stack.
type CfgJumpTable = {
1154+
type: "jump-table"; // discriminator matched by the Cfg's segment switch
1155+
from: MintOpcodePtr; // value of the Cfg's ip at the time jumpTable() was called
1156+
targets: MintOpcodePtr[]; // one branch target per table entry, in table order
1157+
fallthrough: MintOpcodePtr; // default target, emitted as the br_table's final entry
1158+
}
1159+
1160+
type CfgSegment = CfgBlob | CfgBranchBlockHeader | CfgBranch | CfgJumpTable;
11541161

11551162
export const enum CfgBranchType {
11561163
Unconditional,
@@ -1278,6 +1285,23 @@ class Cfg {
12781285
}
12791286
}
12801287

1288+
/**
 * Records a jump-table segment at the current ip. When the Cfg is emitted the segment
 * becomes a br_table over `targets`, with `fallthrough` as the default entry.
 *
 * The caller must already have pushed the selector value, must wrap this call in a
 * block, and must follow that block with a bailout: table entries whose target cannot
 * be resolved branch to the end of that block.
 *
 * @param targets branch target for each table entry, in table order
 * @param fallthrough target used when the selector is outside the table
 */
jumpTable (targets: MintOpcodePtr[], fallthrough: MintOpcodePtr) {
    // Close out whatever was appended so far before starting the new segment
    // (presumably flushes the pending blob - confirm against appendBlob).
    this.appendBlob();

    const jumpTableSegment = {
        type: "jump-table",
        from: this.ip,
        targets,
        fallthrough,
    };
    this.segments.push(jumpTableSegment);

    // All of these sizes are approximations used for overhead accounting:
    const opcodeLengthAndFallthrough = 4, // opcode, length, fallthrough
        branchDepths = targets.length, // one branch depth per table entry
        missingTargetBailout = 24; // bailout for missing targets
    this.overheadBytes += opcodeLengthAndFallthrough + branchDepths + missingTargetBailout;
}
1304+
12811305
emitBlob (segment: CfgBlob, source: Uint8Array) {
12821306
// mono_log_info(`segment @${(<any>segment.ip).toString(16)} ${segment.start}-${segment.start + segment.length}`);
12831307
const view = source.subarray(segment.start, segment.start + segment.length);
@@ -1415,6 +1439,38 @@ class Cfg {
14151439
this.blockStack.shift();
14161440
break;
14171441
}
1442+
case "jump-table": {
1443+
// Our caller wrapped us in a block and put a missing target bailout after us
1444+
const offset = 1;
1445+
// The selector was already loaded onto the wasm stack before cfg.jumpTable was called,
1446+
// so we just need to generate a br_table
1447+
this.builder.appendU8(WasmOpcode.br_table);
1448+
this.builder.appendULeb(segment.targets.length);
1449+
for (const target of segment.targets) {
1450+
const indexInStack = this.blockStack.indexOf(target);
1451+
if (indexInStack >= 0) {
1452+
modifyCounter(JiterpCounter.SwitchTargetsOk, 1);
1453+
this.builder.appendULeb(indexInStack + offset);
1454+
} else {
1455+
modifyCounter(JiterpCounter.SwitchTargetsFailed, 1);
1456+
if (this.trace > 0)
1457+
mono_log_info(`Switch target ${target} not found in block stack ${this.blockStack}`);
1458+
this.builder.appendULeb(0);
1459+
}
1460+
}
1461+
const fallthroughIndex = this.blockStack.indexOf(segment.fallthrough);
1462+
if (fallthroughIndex >= 0) {
1463+
modifyCounter(JiterpCounter.SwitchTargetsOk, 1);
1464+
this.builder.appendULeb(fallthroughIndex + offset);
1465+
} else {
1466+
modifyCounter(JiterpCounter.SwitchTargetsFailed, 1);
1467+
if (this.trace > 0)
1468+
mono_log_info(`Switch fallthrough ${segment.fallthrough} not found in block stack ${this.blockStack}`);
1469+
this.builder.appendULeb(0);
1470+
}
1471+
this.builder.appendU8(WasmOpcode.unreachable);
1472+
break;
1473+
}
14181474
case "branch": {
14191475
const lookupTarget = segment.isBackward ? dispatchIp : segment.target;
14201476
let indexInStack = this.blockStack.indexOf(lookupTarget),
@@ -1965,6 +2021,7 @@ export type JiterpreterOptions = {
19652021
tableSize: number;
19662022
aotTableSize: number;
19672023
maxModuleSize: number;
2024+
maxSwitchSize: number;
19682025
}
19692026

19702027
const optionNames: { [jsName: string]: string } = {
@@ -2002,6 +2059,7 @@ const optionNames: { [jsName: string]: string } = {
20022059
"tableSize": "jiterpreter-table-size",
20032060
"aotTableSize": "jiterpreter-aot-table-size",
20042061
"maxModuleSize": "jiterpreter-max-module-size",
2062+
"maxSwitchSize": "jiterpreter-max-switch-size",
20052063
};
20062064

20072065
let optionsVersion = -1;

src/mono/browser/runtime/jiterpreter-trace-generator.ts

Lines changed: 112 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,34 @@ function getOpcodeLengthU16 (ip: MintOpcodePtr, opcode: MintOpcode) {
184184
}
185185
}
186186

187+
/**
 * Decodes the jump table of a MINT_SWITCH opcode into absolute branch targets.
 *
 * Opcode layout, in 16-bit units relative to `ip` (mirrors the interpreter, see below):
 *   ip[1]     selector local
 *   ip[2..3]  n, the number of table entries (32-bit)
 *   ip[4..]   n 32-bit signed displacements; each one is relative to its own table
 *             slot and is measured in 16-bit units
 *
 * @param ip pointer to the MINT_SWITCH opcode (asserted)
 * @returns the absolute target of every table entry, in table order. Targets may lie
 *          before `ip` when the switch contains backward branches.
 */
function decodeSwitch (ip: MintOpcodePtr) : MintOpcodePtr[] {
    mono_assert(getU16(ip) === MintOpcode.MINT_SWITCH, "decodeSwitch called on a non-switch");
    const n = getArgU32(ip, 2);
    const result: MintOpcodePtr[] = [];
    /*
        guint32 val = LOCAL_VAR (ip [1], guint32);
        guint32 n = READ32 (ip + 2);
        ip += 4;
        if (val < n) {
            ip += 2 * val;
            int offset = READ32 (ip);
            ip += offset;
        } else {
            ip += 2 * n;
        }
    */
    // mono_log_info(`switch[${n}] @${ip}`);
    for (let i = 0; i < n; i++) {
        // Byte address of table slot i: 4 u16s of header (8 bytes), then 4 bytes per entry
        const base = <any>ip + 8 + (4 * i),
            // The interpreter reads the displacement as a signed int. Reinterpret the raw
            // 32 bits as signed so that backward branches decode to an address before
            // the slot instead of a huge positive one.
            offset = getU32_unaligned(base) | 0,
            // Displacements are in 16-bit units, addresses are in bytes
            target = base + (offset * 2);
        // mono_log_info(`  ${i} -> ${target}`);
        result.push(target);
    }

    return result;
}
214+
187215
// Perform a quick scan through the opcodes potentially in this trace to build a table of
188216
// backwards branch targets, compatible with the layout of the old one that was generated in C.
189217
// We do this here to match the exact way that the jiterp calculates branch targets, since
@@ -205,47 +233,60 @@ export function generateBackwardBranchTable (
205233
const opcode = <MintOpcode>getU16(ip);
206234
const opLengthU16 = getOpcodeLengthU16(ip, opcode);
207235

208-
// Any opcode with a branch argtype will have a decoded displacement, even if we don't
209-
// implement the opcode. Everything else will return undefined here and be skipped
210-
const displacement = getBranchDisplacement(ip, opcode);
211-
if (typeof (displacement) !== "number") {
212-
ip += <any>(opLengthU16 * 2);
213-
continue;
214-
}
215-
216-
// These checks shouldn't fail unless memory is corrupted or something is wrong with the decoder.
217-
// We don't want to cause decoder bugs to make the application exit, though - graceful degradation.
218-
if (displacement === 0) {
219-
mono_log_info(`opcode @${ip} branch target is self. aborting backbranch table generation`);
220-
break;
221-
}
236+
if (opcode === MintOpcode.MINT_SWITCH) {
237+
// FIXME: Once the cfg supports back-branches in jump tables, uncomment this to
238+
// insert the back-branch targets into the table so they'll actually work
239+
/*
240+
const switchTable = decodeSwitch(ip);
241+
for (const target of switchTable) {
242+
const rtarget16 = (<any>target - <any>startOfBody) / 2;
243+
if (target < ip)
244+
table.push(rtarget16);
245+
}
246+
*/
247+
} else {
248+
// Any opcode with a branch argtype will have a decoded displacement, even if we don't
249+
// implement the opcode. Everything else will return undefined here and be skipped
250+
const displacement = getBranchDisplacement(ip, opcode);
251+
if (typeof (displacement) !== "number") {
252+
ip += <any>(opLengthU16 * 2);
253+
continue;
254+
}
222255

223-
// Only record *backward* branches
224-
// We will filter this down further in the Cfg because it takes note of which branches it sees,
225-
// but it is also beneficial to have a null table (further down) due to seeing no potential
226-
// back branch targets at all, as it allows the Cfg to skip additional code generation entirely
227-
// if it knows there will never be any backwards branches in a given trace
228-
if (displacement < 0) {
229-
const rtarget16 = rip16 + (displacement);
230-
if (rtarget16 < 0) {
231-
mono_log_info(`opcode @${ip}'s displacement of ${displacement} goes before body: ${rtarget16}. aborting backbranch table generation`);
256+
// These checks shouldn't fail unless memory is corrupted or something is wrong with the decoder.
257+
// We don't want to cause decoder bugs to make the application exit, though - graceful degradation.
258+
if (displacement === 0) {
259+
mono_log_info(`opcode @${ip} branch target is self. aborting backbranch table generation`);
232260
break;
233261
}
234262

235-
// If the relative target is before the start of the trace, don't record it.
236-
// The trace will be unable to successfully branch to it so it would just make the table bigger.
237-
if (rtarget16 >= rbase16)
238-
table.push(rtarget16);
239-
}
263+
// Only record *backward* branches
264+
// We will filter this down further in the Cfg because it takes note of which branches it sees,
265+
// but it is also beneficial to have a null table (further down) due to seeing no potential
266+
// back branch targets at all, as it allows the Cfg to skip additional code generation entirely
267+
// if it knows there will never be any backwards branches in a given trace
268+
if (displacement < 0) {
269+
const rtarget16 = rip16 + (displacement);
270+
if (rtarget16 < 0) {
271+
mono_log_info(`opcode @${ip}'s displacement of ${displacement} goes before body: ${rtarget16}. aborting backbranch table generation`);
272+
break;
273+
}
240274

241-
switch (opcode) {
242-
case MintOpcode.MINT_CALL_HANDLER:
243-
case MintOpcode.MINT_CALL_HANDLER_S:
244-
// While this formally isn't a backward branch target, we want to record
245-
// the offset of its following instruction so that the jiterpreter knows
246-
// to generate the necessary dispatch code to enable branching back to it.
247-
table.push(rip16 + opLengthU16);
248-
break;
275+
// If the relative target is before the start of the trace, don't record it.
276+
// The trace will be unable to successfully branch to it so it would just make the table bigger.
277+
if (rtarget16 >= rbase16)
278+
table.push(rtarget16);
279+
}
280+
281+
switch (opcode) {
282+
case MintOpcode.MINT_CALL_HANDLER:
283+
case MintOpcode.MINT_CALL_HANDLER_S:
284+
// While this formally isn't a backward branch target, we want to record
285+
// the offset of its following instruction so that the jiterpreter knows
286+
// to generate the necessary dispatch code to enable branching back to it.
287+
table.push(rip16 + opLengthU16);
288+
break;
289+
}
249290
}
250291

251292
ip += <any>(opLengthU16 * 2);
@@ -399,7 +440,7 @@ export function generateWasmBody (
399440

400441
switch (opcode) {
401442
case MintOpcode.MINT_SWITCH: {
402-
if (!emit_switch(builder, ip))
443+
if (!emit_switch(builder, ip, exitOpcodeCounter))
403444
ip = abort;
404445
break;
405446
}
@@ -4036,7 +4077,39 @@ function emit_atomics (
40364077
return false;
40374078
}
40384079

4039-
function emit_switch (builder: WasmBuilder, ip: MintOpcodePtr) : boolean {
4040-
append_bailout(builder, ip, BailoutReason.Switch);
4080+
/**
 * Emits code for a MINT_SWITCH opcode.
 *
 * If the switch has more entries than the `maxSwitchSize` option allows, a SwitchSize
 * bailout is emitted in its place. Otherwise the selector is loaded and a jump table is
 * queued on the cfg, wrapped in a block that is followed by a SwitchTarget exit for
 * table entries that could not be resolved to a branch.
 *
 * @param builder the builder for the trace being generated
 * @param ip pointer to the MINT_SWITCH opcode
 * @param exitOpcodeCounter forwarded to append_exit for the missing-target exit
 * @returns always true; the caller only aborts the trace on a false result
 */
function emit_switch (builder: WasmBuilder, ip: MintOpcodePtr, exitOpcodeCounter: number) : boolean {
    const opLengthU16 = getOpcodeLengthU16(ip, MintOpcode.MINT_SWITCH);
    const switchTargets = decodeSwitch(ip);

    // Oversized switches are not compiled at all: count every entry as failed and bail out.
    if (switchTargets.length > builder.options.maxSwitchSize) {
        modifyCounter(JiterpCounter.SwitchTargetsFailed, switchTargets.length);
        append_bailout(builder, ip, BailoutReason.SwitchSize);
        return true;
    }

    // Only forward targets are registered as branch targets.
    // Any back branches in the table will bail out at runtime.
    for (const forwardTarget of switchTargets.filter(candidate => candidate > ip))
        builder.branchTargets.add(forwardTarget);

    // The instruction right after the switch is where out-of-range selectors continue.
    const nextIp = <any>ip + (opLengthU16 * 2);
    builder.branchTargets.add(nextIp);

    // The jump table sits inside its own block so missing targets can `br 0` out of it
    builder.block();
    // Push the selector value
    append_ldloc(builder, getArgU16(ip, 1), WasmOpcode.i32_load);
    // Queue the dispatch itself
    builder.cfg.jumpTable(switchTargets, nextIp);
    builder.endBlock();
    // Landing point for missing targets
    append_exit(builder, ip, exitOpcodeCounter, BailoutReason.SwitchTarget);
    return true;
}

src/mono/browser/runtime/jiterpreter.ts

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ export const callTargetCounts: { [method: number]: number } = {};
7979

8080
export let mostRecentTrace: InstrumentedTraceState | undefined;
8181
export let mostRecentOptions: JiterpreterOptions | undefined = undefined;
82+
export let traceTableIsFull = false;
8283

8384
// You can disable an opcode for debugging purposes by adding it to this list,
8485
// instead of aborting the trace it will insert a bailout instead. This means that you will
@@ -861,6 +862,11 @@ function generate_wasm (
861862
idx = presetFunctionPointer;
862863
} else {
863864
idx = addWasmFunctionPointer(JiterpreterTable.Trace, <any>fn);
865+
if (idx === 0) {
866+
// Failed to add function pointer because trace table is full. Disable future
867+
// trace generation to reduce CPU usage.
868+
traceTableIsFull = true;
869+
}
864870
}
865871
if (trace >= 2)
866872
mono_log_info(`${traceName} -> fn index ${idx}`);
@@ -984,6 +990,8 @@ export function mono_interp_tier_prepare_jiterpreter (
984990
return JITERPRETER_NOT_JITTED;
985991
else if (mostRecentOptions.wasmBytesLimit <= getCounter(JiterpCounter.BytesGenerated))
986992
return JITERPRETER_NOT_JITTED;
993+
else if (traceTableIsFull)
994+
return JITERPRETER_NOT_JITTED;
987995

988996
let info = traceInfo[index];
989997

@@ -1078,7 +1086,9 @@ export function jiterpreter_dump_stats (concise?: boolean): void {
10781086
traceCandidates = getCounter(JiterpCounter.TraceCandidates),
10791087
bytesGenerated = getCounter(JiterpCounter.BytesGenerated),
10801088
elapsedGenerationMs = getCounter(JiterpCounter.ElapsedGenerationMs),
1081-
elapsedCompilationMs = getCounter(JiterpCounter.ElapsedCompilationMs);
1089+
elapsedCompilationMs = getCounter(JiterpCounter.ElapsedCompilationMs),
1090+
switchTargetsOk = getCounter(JiterpCounter.SwitchTargetsOk),
1091+
switchTargetsFailed = getCounter(JiterpCounter.SwitchTargetsFailed);
10821092

10831093
const backBranchHitRate = (backBranchesEmitted / (backBranchesEmitted + backBranchesNotEmitted)) * 100,
10841094
tracesRejected = cwraps.mono_jiterp_get_rejected_trace_count(),
@@ -1089,8 +1099,8 @@ export function jiterpreter_dump_stats (concise?: boolean): void {
10891099
mostRecentOptions.directJitCalls ? `direct jit calls: ${directJitCallsCompiled} (${(directJitCallsCompiled / jitCallsCompiled * 100).toFixed(1)}%)` : "direct jit calls: off"
10901100
) : "";
10911101

1092-
mono_log_info(`// jitted ${bytesGenerated} bytes; ${tracesCompiled} traces (${(tracesCompiled / traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${jitCallsCompiled} jit_calls; ${entryWrappersCompiled} interp_entries`);
1093-
mono_log_info(`// cknulls eliminated: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-branches ${backBranchesEmittedText}; ${directJitCallsText}`);
1102+
mono_log_info(`// jitted ${bytesGenerated}b; ${tracesCompiled} traces (${(tracesCompiled / traceCandidates * 100).toFixed(1)}%) (${tracesRejected} rejected); ${jitCallsCompiled} jit_calls; ${entryWrappersCompiled} interp_entries`);
1103+
mono_log_info(`// cknulls pruned: ${nullChecksEliminatedText}, fused: ${nullChecksFusedText}; back-brs ${backBranchesEmittedText}; switch tgts ${switchTargetsOk}/${switchTargetsFailed + switchTargetsOk}; ${directJitCallsText}`);
10941104
mono_log_info(`// time: ${elapsedGenerationMs | 0}ms generating, ${elapsedCompilationMs | 0}ms compiling wasm.`);
10951105
if (concise)
10961106
return;

src/mono/mono/mini/interp/jiterpreter.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,9 @@ enum {
12571257
JITERP_COUNTER_BACK_BRANCHES_NOT_EMITTED,
12581258
JITERP_COUNTER_ELAPSED_GENERATION,
12591259
JITERP_COUNTER_ELAPSED_COMPILATION,
1260-
JITERP_COUNTER_MAX = JITERP_COUNTER_ELAPSED_COMPILATION
1260+
JITERP_COUNTER_SWITCH_TARGETS_OK,
1261+
JITERP_COUNTER_SWITCH_TARGETS_FAILED,
1262+
JITERP_COUNTER_MAX = JITERP_COUNTER_SWITCH_TARGETS_FAILED
12611263
};
12621264

12631265
#define JITERP_COUNTER_UNIT 100

src/mono/mono/utils/options-def.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ DEFINE_INT(jiterpreter_table_size, "jiterpreter-table-size", 6 * 1024, "Size of
177177
// FIXME: In the future if we find a way to reduce the number of unique tables we can raise this constant
178178
DEFINE_INT(jiterpreter_aot_table_size, "jiterpreter-aot-table-size", 3 * 1024, "Size of the jiterpreter AOT trampoline function tables")
179179
DEFINE_INT(jiterpreter_max_module_size, "jiterpreter-max-module-size", 4080, "Size limit for jiterpreter generated WASM modules")
180+
DEFINE_INT(jiterpreter_max_switch_size, "jiterpreter-max-switch-size", 24, "Size limit for jiterpreter switch opcodes (0 to disable)")
180181
#endif // HOST_BROWSER
181182

182183
#if defined(TARGET_WASM) || defined(TARGET_IOS) || defined(TARGET_TVOS) || defined (TARGET_MACCAT)

0 commit comments

Comments
 (0)