Skip to content

Commit 0501962

Browse files
authored
Merge pull request #11321 from Luukdegram/wasm-overflow
stage2: wasm - Implement overflow arithmetic
2 parents d15bbeb + e1bb096 commit 0501962

File tree

2 files changed

+156
-32
lines changed

2 files changed

+156
-32
lines changed

src/arch/wasm/CodeGen.zig

Lines changed: 156 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,10 +1303,16 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
13031303
.bool_and => self.airBinOp(inst, .@"and"),
13041304
.bool_or => self.airBinOp(inst, .@"or"),
13051305
.rem => self.airBinOp(inst, .rem),
1306-
.shl, .shl_exact => self.airBinOp(inst, .shl),
1306+
.shl => self.airWrapBinOp(inst, .shl),
1307+
.shl_exact => self.airBinOp(inst, .shl),
13071308
.shr, .shr_exact => self.airBinOp(inst, .shr),
13081309
.xor => self.airBinOp(inst, .xor),
13091310

1311+
.add_with_overflow => self.airBinOpOverflow(inst, .add),
1312+
.sub_with_overflow => self.airBinOpOverflow(inst, .sub),
1313+
.shl_with_overflow => self.airBinOpOverflow(inst, .shl),
1314+
.mul_with_overflow => self.airBinOpOverflow(inst, .mul),
1315+
13101316
.cmp_eq => self.airCmp(inst, .eq),
13111317
.cmp_gte => self.airCmp(inst, .gte),
13121318
.cmp_gt => self.airCmp(inst, .gt),
@@ -1461,13 +1467,6 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
14611467
.atomic_rmw,
14621468
.tag_name,
14631469
.mul_add,
1464-
1465-
// For these 4, probably best to wait until https://github.com/ziglang/zig/issues/10248
1466-
// is implemented in the frontend before implementing them here in the wasm backend.
1467-
.add_with_overflow,
1468-
.sub_with_overflow,
1469-
.mul_with_overflow,
1470-
.shl_with_overflow,
14711470
=> |tag| return self.fail("TODO: Implement wasm inst: {s}", .{@tagName(tag)}),
14721471
};
14731472
}
@@ -1754,24 +1753,28 @@ fn airBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
17541753
const lhs = try self.resolveInst(bin_op.lhs);
17551754
const rhs = try self.resolveInst(bin_op.rhs);
17561755
const operand_ty = self.air.typeOfIndex(inst);
1756+
const ty = self.air.typeOf(bin_op.lhs);
17571757

17581758
if (isByRef(operand_ty, self.target)) {
17591759
return self.fail("TODO: Implement binary operation for type: {}", .{operand_ty.fmtDebug()});
17601760
}
17611761

1762+
return self.binOp(lhs, rhs, ty, op);
1763+
}
1764+
1765+
fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
17621766
try self.emitWValue(lhs);
17631767
try self.emitWValue(rhs);
17641768

1765-
const bin_ty = self.air.typeOf(bin_op.lhs);
17661769
const opcode: wasm.Opcode = buildOpcode(.{
17671770
.op = op,
1768-
.valtype1 = typeToValtype(bin_ty, self.target),
1769-
.signedness = if (bin_ty.isSignedInt()) .signed else .unsigned,
1771+
.valtype1 = typeToValtype(ty, self.target),
1772+
.signedness = if (ty.isSignedInt()) .signed else .unsigned,
17701773
});
17711774
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
17721775

17731776
// save the result in a temporary
1774-
const bin_local = try self.allocLocal(bin_ty);
1777+
const bin_local = try self.allocLocal(ty);
17751778
try self.addLabel(.local_set, bin_local.local);
17761779
return bin_local;
17771780
}
@@ -1781,18 +1784,21 @@ fn airWrapBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
17811784
const lhs = try self.resolveInst(bin_op.lhs);
17821785
const rhs = try self.resolveInst(bin_op.rhs);
17831786

1787+
return self.wrapBinOp(lhs, rhs, self.air.typeOf(bin_op.lhs), op);
1788+
}
1789+
1790+
fn wrapBinOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
17841791
try self.emitWValue(lhs);
17851792
try self.emitWValue(rhs);
17861793

1787-
const bin_ty = self.air.typeOf(bin_op.lhs);
17881794
const opcode: wasm.Opcode = buildOpcode(.{
17891795
.op = op,
1790-
.valtype1 = typeToValtype(bin_ty, self.target),
1791-
.signedness = if (bin_ty.isSignedInt()) .signed else .unsigned,
1796+
.valtype1 = typeToValtype(ty, self.target),
1797+
.signedness = if (ty.isSignedInt()) .signed else .unsigned,
17921798
});
17931799
try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
17941800

1795-
const int_info = bin_ty.intInfo(self.target);
1801+
const int_info = ty.intInfo(self.target);
17961802
const bitsize = int_info.bits;
17971803
const is_signed = int_info.signedness == .signed;
17981804
// if target type bitsize is x < 32 and 32 > x < 64, we perform
@@ -1820,7 +1826,7 @@ fn airWrapBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
18201826
}
18211827

18221828
// save the result in a temporary
1823-
const bin_local = try self.allocLocal(bin_ty);
1829+
const bin_local = try self.allocLocal(ty);
18241830
try self.addLabel(.local_set, bin_local.local);
18251831
return bin_local;
18261832
}
@@ -2202,18 +2208,21 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: std.math.CompareOperator) Inner
22022208
const lhs = try self.resolveInst(bin_op.lhs);
22032209
const rhs = try self.resolveInst(bin_op.rhs);
22042210
const operand_ty = self.air.typeOf(bin_op.lhs);
2211+
return self.cmp(lhs, rhs, operand_ty, op);
2212+
}
22052213

2206-
if (operand_ty.zigTypeTag() == .Optional and !operand_ty.isPtrLikeOptional()) {
2214+
fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOperator) InnerError!WValue {
2215+
if (ty.zigTypeTag() == .Optional and !ty.isPtrLikeOptional()) {
22072216
var buf: Type.Payload.ElemType = undefined;
2208-
const payload_ty = operand_ty.optionalChild(&buf);
2217+
const payload_ty = ty.optionalChild(&buf);
22092218
if (payload_ty.hasRuntimeBitsIgnoreComptime()) {
22102219
// When we hit this case, we must check the value of optionals
22112220
// that are not pointers. This means first checking against non-null for
22122221
// both lhs and rhs, as well as checking the payload are matching of lhs and rhs
2213-
return self.cmpOptionals(lhs, rhs, operand_ty, op);
2222+
return self.cmpOptionals(lhs, rhs, ty, op);
22142223
}
2215-
} else if (isByRef(operand_ty, self.target)) {
2216-
return self.cmpBigInt(lhs, rhs, operand_ty, op);
2224+
} else if (isByRef(ty, self.target)) {
2225+
return self.cmpBigInt(lhs, rhs, ty, op);
22172226
}
22182227

22192228
// ensure that when we compare pointers, we emit
@@ -2229,13 +2238,13 @@ fn airCmp(self: *Self, inst: Air.Inst.Index, op: std.math.CompareOperator) Inner
22292238

22302239
const signedness: std.builtin.Signedness = blk: {
22312240
// by default we tell the operand type is unsigned (i.e. bools and enum values)
2232-
if (operand_ty.zigTypeTag() != .Int) break :blk .unsigned;
2241+
if (ty.zigTypeTag() != .Int) break :blk .unsigned;
22332242

22342243
// incase of an actual integer, we emit the correct signedness
2235-
break :blk operand_ty.intInfo(self.target).signedness;
2244+
break :blk ty.intInfo(self.target).signedness;
22362245
};
22372246
const opcode: wasm.Opcode = buildOpcode(.{
2238-
.valtype1 = typeToValtype(operand_ty, self.target),
2247+
.valtype1 = typeToValtype(ty, self.target),
22392248
.op = switch (op) {
22402249
.lt => .lt,
22412250
.lte => .le,
@@ -3730,3 +3739,125 @@ fn airPtrSliceFieldPtr(self: *Self, inst: Air.Inst.Index, offset: u32) InnerErro
37303739
const slice_ptr = try self.resolveInst(ty_op.operand);
37313740
return self.buildPointerOffset(slice_ptr, offset, .new);
37323741
}
3742+
3743+
fn airBinOpOverflow(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {
3744+
if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
3745+
3746+
const ty_pl = self.air.instructions.items(.data)[inst].ty_pl;
3747+
const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
3748+
const lhs = try self.resolveInst(extra.lhs);
3749+
const rhs = try self.resolveInst(extra.rhs);
3750+
const lhs_ty = self.air.typeOf(extra.lhs);
3751+
3752+
// We store the bit if it's overflowed or not in this. As it's zero-initialized
3753+
// we only need to update it if an overflow (or underflow) occured.
3754+
const overflow_bit = try self.allocLocal(Type.initTag(.u1));
3755+
const int_info = lhs_ty.intInfo(self.target);
3756+
const wasm_bits = toWasmBits(int_info.bits) orelse {
3757+
return self.fail("TODO: Implement overflow arithmetic for integer bitsize: {d}", .{int_info.bits});
3758+
};
3759+
3760+
const zero = switch (wasm_bits) {
3761+
32 => WValue{ .imm32 = 0 },
3762+
64 => WValue{ .imm64 = 0 },
3763+
else => unreachable,
3764+
};
3765+
const int_max = (@as(u65, 1) << @intCast(u7, int_info.bits - @boolToInt(int_info.signedness == .signed))) - 1;
3766+
const int_max_wvalue = switch (wasm_bits) {
3767+
32 => WValue{ .imm32 = @intCast(u32, int_max) },
3768+
64 => WValue{ .imm64 = @intCast(u64, int_max) },
3769+
else => unreachable,
3770+
};
3771+
const int_min = if (int_info.signedness == .unsigned)
3772+
@as(i64, 0)
3773+
else
3774+
-@as(i64, 1) << @intCast(u6, int_info.bits - 1);
3775+
const int_min_wvalue = switch (wasm_bits) {
3776+
32 => WValue{ .imm32 = @bitCast(u32, @intCast(i32, int_min)) },
3777+
64 => WValue{ .imm64 = @bitCast(u64, int_min) },
3778+
else => unreachable,
3779+
};
3780+
3781+
if (int_info.signedness == .unsigned and op == .add) {
3782+
const diff = try self.binOp(int_max_wvalue, lhs, lhs_ty, .sub);
3783+
const cmp_res = try self.cmp(rhs, diff, lhs_ty, .gt);
3784+
try self.emitWValue(cmp_res);
3785+
try self.addLabel(.local_set, overflow_bit.local);
3786+
} else if (int_info.signedness == .unsigned and op == .sub) {
3787+
const cmp_res = try self.cmp(lhs, rhs, lhs_ty, .lt);
3788+
try self.emitWValue(cmp_res);
3789+
try self.addLabel(.local_set, overflow_bit.local);
3790+
} else if (int_info.signedness == .signed and op != .shl) {
3791+
// for overflow, we first check if lhs is > 0 (or lhs < 0 in case of subtraction). If not, we will not overflow.
3792+
// We first create an outer block, where we handle overflow.
3793+
// Then we create an inner block, where underflow is handled.
3794+
try self.startBlock(.block, wasm.block_empty);
3795+
try self.startBlock(.block, wasm.block_empty);
3796+
{
3797+
try self.emitWValue(lhs);
3798+
const cmp_result = try self.cmp(lhs, zero, lhs_ty, .lt);
3799+
try self.emitWValue(cmp_result);
3800+
}
3801+
try self.addLabel(.br_if, 0); // break to outer block, and handle underflow
3802+
3803+
// handle overflow
3804+
{
3805+
const diff = try self.binOp(int_max_wvalue, lhs, lhs_ty, .sub);
3806+
const cmp_res = try self.cmp(rhs, diff, lhs_ty, if (op == .add) .gt else .lt);
3807+
try self.emitWValue(cmp_res);
3808+
try self.addLabel(.local_set, overflow_bit.local);
3809+
}
3810+
try self.addLabel(.br, 1); // break from blocks, and continue regular flow.
3811+
try self.endBlock();
3812+
3813+
// handle underflow
3814+
{
3815+
const diff = try self.binOp(int_min_wvalue, lhs, lhs_ty, .sub);
3816+
const cmp_res = try self.cmp(rhs, diff, lhs_ty, if (op == .add) .lt else .gt);
3817+
try self.emitWValue(cmp_res);
3818+
try self.addLabel(.local_set, overflow_bit.local);
3819+
}
3820+
try self.endBlock();
3821+
}
3822+
3823+
const bin_op = if (op == .shl) blk: {
3824+
const tmp_val = try self.binOp(lhs, rhs, lhs_ty, op);
3825+
const cmp_res = try self.cmp(tmp_val, int_max_wvalue, lhs_ty, .gt);
3826+
try self.emitWValue(cmp_res);
3827+
try self.addLabel(.local_set, overflow_bit.local);
3828+
3829+
try self.emitWValue(tmp_val);
3830+
try self.emitWValue(int_max_wvalue);
3831+
switch (wasm_bits) {
3832+
32 => try self.addTag(.i32_and),
3833+
64 => try self.addTag(.i64_and),
3834+
else => unreachable,
3835+
}
3836+
try self.addLabel(.local_set, tmp_val.local);
3837+
break :blk tmp_val;
3838+
} else if (op == .mul) blk: {
3839+
const bin_op = try self.wrapBinOp(lhs, rhs, lhs_ty, op);
3840+
try self.startBlock(.block, wasm.block_empty);
3841+
// check if 0. true => Break out of block as cannot over -or underflow.
3842+
try self.emitWValue(lhs);
3843+
switch (wasm_bits) {
3844+
32 => try self.addTag(.i32_eqz),
3845+
64 => try self.addTag(.i64_eqz),
3846+
else => unreachable,
3847+
}
3848+
try self.addLabel(.br_if, 0);
3849+
const div = try self.binOp(bin_op, lhs, lhs_ty, .div);
3850+
const cmp_res = try self.cmp(div, rhs, lhs_ty, .neq);
3851+
try self.emitWValue(cmp_res);
3852+
try self.addLabel(.local_set, overflow_bit.local);
3853+
try self.endBlock();
3854+
break :blk bin_op;
3855+
} else try self.wrapBinOp(lhs, rhs, lhs_ty, op);
3856+
3857+
const result_ptr = try self.allocStack(self.air.typeOfIndex(inst));
3858+
try self.store(result_ptr, bin_op, lhs_ty, 0);
3859+
const offset = @intCast(u32, lhs_ty.abiSize(self.target));
3860+
try self.store(result_ptr, overflow_bit, Type.initTag(.u1), offset);
3861+
3862+
return result_ptr;
3863+
}

test/behavior/math.zig

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,6 @@ test "128-bit multiplication" {
639639

640640
test "@addWithOverflow" {
641641
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
642-
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
643642
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
644643
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
645644
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -661,7 +660,6 @@ test "@addWithOverflow" {
661660

662661
test "small int addition" {
663662
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
664-
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
665663
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
666664
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
667665
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -686,7 +684,6 @@ test "small int addition" {
686684

687685
test "@mulWithOverflow" {
688686
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
689-
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
690687
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
691688
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
692689
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -708,7 +705,6 @@ test "@mulWithOverflow" {
708705

709706
test "@subWithOverflow" {
710707
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
711-
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
712708
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
713709
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
714710
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -730,7 +726,6 @@ test "@subWithOverflow" {
730726

731727
test "@shlWithOverflow" {
732728
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
733-
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
734729
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
735730
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
736731
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
@@ -752,7 +747,6 @@ test "@shlWithOverflow" {
752747

753748
test "overflow arithmetic with u0 values" {
754749
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
755-
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
756750
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
757751

758752
var result: u0 = undefined;
@@ -879,7 +873,6 @@ test "quad hex float literal parsing accurate" {
879873
}
880874

881875
test "truncating shift left" {
882-
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
883876
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
884877

885878
try testShlTrunc(maxInt(u16));

0 commit comments

Comments
 (0)