
Commit a6ae451

add @newStackCall builtin function
See #1006
1 parent 4277762 commit a6ae451

8 files changed, +298 -16 lines changed

doc/langref.html.in

Lines changed: 44 additions & 3 deletions
@@ -4485,17 +4485,58 @@ mem.set(u8, dest, c);</code></pre>
 If no overflow or underflow occurs, returns <code>false</code>.
 </p>
 {#header_close#}
+{#header_open|@newStackCall#}
+<pre><code class="zig">@newStackCall(new_stack: []u8, function: var, args: ...) -&gt; var</code></pre>
+<p>
+This calls a function, in the same way that invoking an expression with parentheses does. However,
+instead of using the same stack as the caller, the function uses the stack provided in the <code>new_stack</code>
+parameter.
+</p>
+{#code_begin|test#}
+const std = @import("std");
+const assert = std.debug.assert;
+
+var new_stack_bytes: [1024]u8 = undefined;
+
+test "calling a function with a new stack" {
+    const arg = 1234;
+
+    const a = @newStackCall(new_stack_bytes[0..512], targetFunction, arg);
+    const b = @newStackCall(new_stack_bytes[512..], targetFunction, arg);
+    _ = targetFunction(arg);
+
+    assert(arg == 1234);
+    assert(a < b);
+}
+
+fn targetFunction(x: i32) usize {
+    assert(x == 1234);
+
+    var local_variable: i32 = 42;
+    const ptr = &local_variable;
+    *ptr += 1;
+
+    assert(local_variable == 43);
+    return @ptrToInt(ptr);
+}
+{#code_end#}
+{#header_close#}
 {#header_open|@noInlineCall#}
 <pre><code class="zig">@noInlineCall(function: var, args: ...) -&gt; var</code></pre>
 <p>
 This calls a function, in the same way that invoking an expression with parentheses does:
 </p>
-<pre><code class="zig">const assert = @import("std").debug.assert;
+{#code_begin|test#}
+const assert = @import("std").debug.assert;
+
 test "noinline function call" {
     assert(@noInlineCall(add, 3, 9) == 12);
 }

-fn add(a: i32, b: i32) -&gt; i32 { a + b }</code></pre>
+fn add(a: i32, b: i32) i32 {
+    return a + b;
+}
+{#code_end#}
 <p>
 Unlike a normal function call, however, <code>@noInlineCall</code> guarantees that the call
 will not be inlined. If the call must be inlined, a compile error is emitted.
@@ -6451,7 +6492,7 @@ hljs.registerLanguage("zig", function(t) {
     a = t.IR + "\\s*\\(",
     c = {
       keyword: "const align var extern stdcallcc nakedcc volatile export pub noalias inline struct packed enum union break return try catch test continue unreachable comptime and or asm defer errdefer if else switch while for fn use bool f32 f64 void type noreturn error i8 u8 i16 u16 i32 u32 i64 u64 isize usize i8w u8w i16w i32w u32w i64w u64w isizew usizew c_short c_ushort c_int c_uint c_long c_ulong c_longlong c_ulonglong",
-      built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchgStrong cmpxchgWeak fence divExact truncate atomicRmw sqrt field typeInfo",
+      built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchgStrong cmpxchgWeak fence divExact truncate atomicRmw sqrt field typeInfo newStackCall",
       literal: "true false null undefined"
     },
     n = [e, t.CLCM, t.CBCM, s, r];
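Not part of the commit, but a minimal usage sketch of the semantics documented above: because the callee's frames are placed in the provided slice, a whole call chain can be pinned to a dedicated buffer. The buffer size, function names, and recursion depth below are made up for illustration, and there is no overflow checking.

const std = @import("std");
const assert = std.debug.assert;

// Hypothetical scratch buffer; 4096 bytes is an arbitrary size for this sketch.
var scratch_stack: [4096]u8 = undefined;

test "recursing on a dedicated stack" {
    const depth = 10;
    // sum and all of its recursive calls run with their frames inside scratch_stack.
    const result = @newStackCall(scratch_stack[0..], sum, depth);
    assert(result == 55);
}

fn sum(n: u32) u32 {
    if (n == 0) return 0;
    return n + sum(n - 1);
}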

src/all_types.hpp

Lines changed: 7 additions & 0 deletions
@@ -1340,6 +1340,7 @@ enum BuiltinFnId {
     BuiltinFnIdOffsetOf,
     BuiltinFnIdInlineCall,
     BuiltinFnIdNoInlineCall,
+    BuiltinFnIdNewStackCall,
     BuiltinFnIdTypeId,
     BuiltinFnIdShlExact,
     BuiltinFnIdShrExact,
@@ -1656,8 +1657,13 @@ struct CodeGen {
     LLVMValueRef coro_alloc_helper_fn_val;
     LLVMValueRef merge_err_ret_traces_fn_val;
     LLVMValueRef add_error_return_trace_addr_fn_val;
+    LLVMValueRef stacksave_fn_val;
+    LLVMValueRef stackrestore_fn_val;
+    LLVMValueRef write_register_fn_val;
     bool error_during_imports;

+    LLVMValueRef sp_md_node;
+
     const char **clang_argv;
     size_t clang_argv_len;
     ZigList<const char *> lib_dirs;
@@ -2280,6 +2286,7 @@ struct IrInstructionCall {
     bool is_async;

     IrInstruction *async_allocator;
+    IrInstruction *new_stack;
 };

 struct IrInstructionConst {

src/codegen.cpp

Lines changed: 97 additions & 2 deletions
@@ -938,6 +938,53 @@ static LLVMValueRef get_memcpy_fn_val(CodeGen *g) {
     return g->memcpy_fn_val;
 }

+static LLVMValueRef get_stacksave_fn_val(CodeGen *g) {
+    if (g->stacksave_fn_val)
+        return g->stacksave_fn_val;
+
+    // declare i8* @llvm.stacksave()
+
+    LLVMTypeRef fn_type = LLVMFunctionType(LLVMPointerType(LLVMInt8Type(), 0), nullptr, 0, false);
+    g->stacksave_fn_val = LLVMAddFunction(g->module, "llvm.stacksave", fn_type);
+    assert(LLVMGetIntrinsicID(g->stacksave_fn_val));
+
+    return g->stacksave_fn_val;
+}
+
+static LLVMValueRef get_stackrestore_fn_val(CodeGen *g) {
+    if (g->stackrestore_fn_val)
+        return g->stackrestore_fn_val;
+
+    // declare void @llvm.stackrestore(i8* %ptr)
+
+    LLVMTypeRef param_type = LLVMPointerType(LLVMInt8Type(), 0);
+    LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), &param_type, 1, false);
+    g->stackrestore_fn_val = LLVMAddFunction(g->module, "llvm.stackrestore", fn_type);
+    assert(LLVMGetIntrinsicID(g->stackrestore_fn_val));
+
+    return g->stackrestore_fn_val;
+}
+
+static LLVMValueRef get_write_register_fn_val(CodeGen *g) {
+    if (g->write_register_fn_val)
+        return g->write_register_fn_val;
+
+    // declare void @llvm.write_register.i64(metadata, i64 @value)
+    // !0 = !{!"sp\00"}
+
+    LLVMTypeRef param_types[] = {
+        LLVMMetadataTypeInContext(LLVMGetGlobalContext()),
+        LLVMIntType(g->pointer_size_bytes * 8),
+    };
+
+    LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 2, false);
+    Buf *name = buf_sprintf("llvm.write_register.i%d", g->pointer_size_bytes * 8);
+    g->write_register_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
+    assert(LLVMGetIntrinsicID(g->write_register_fn_val));
+
+    return g->write_register_fn_val;
+}
+
 static LLVMValueRef get_coro_destroy_fn_val(CodeGen *g) {
     if (g->coro_destroy_fn_val)
         return g->coro_destroy_fn_val;
@@ -2901,6 +2948,38 @@ static size_t get_async_err_code_arg_index(CodeGen *g, FnTypeId *fn_type_id) {
     return 1 + get_async_allocator_arg_index(g, fn_type_id);
 }

+
+static LLVMValueRef get_new_stack_addr(CodeGen *g, LLVMValueRef new_stack) {
+    LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, new_stack, (unsigned)slice_ptr_index, "");
+    LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, new_stack, (unsigned)slice_len_index, "");
+
+    LLVMValueRef ptr_value = gen_load_untyped(g, ptr_field_ptr, 0, false, "");
+    LLVMValueRef len_value = gen_load_untyped(g, len_field_ptr, 0, false, "");
+
+    LLVMValueRef ptr_addr = LLVMBuildPtrToInt(g->builder, ptr_value, LLVMTypeOf(len_value), "");
+    LLVMValueRef end_addr = LLVMBuildNUWAdd(g->builder, ptr_addr, len_value, "");
+    LLVMValueRef align_amt = LLVMConstInt(LLVMTypeOf(end_addr), get_abi_alignment(g, g->builtin_types.entry_usize), false);
+    LLVMValueRef align_adj = LLVMBuildURem(g->builder, end_addr, align_amt, "");
+    return LLVMBuildNUWSub(g->builder, end_addr, align_adj, "");
+}
+
+static void gen_set_stack_pointer(CodeGen *g, LLVMValueRef aligned_end_addr) {
+    LLVMValueRef write_register_fn_val = get_write_register_fn_val(g);
+
+    if (g->sp_md_node == nullptr) {
+        Buf *sp_reg_name = buf_create_from_str(arch_stack_pointer_register_name(&g->zig_target.arch));
+        LLVMValueRef str_node = LLVMMDString(buf_ptr(sp_reg_name), buf_len(sp_reg_name) + 1);
+        g->sp_md_node = LLVMMDNode(&str_node, 1);
+    }
+
+    LLVMValueRef params[] = {
+        g->sp_md_node,
+        aligned_end_addr,
+    };
+
+    LLVMBuildCall(g->builder, write_register_fn_val, params, 2, "");
+}
+
 static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstructionCall *instruction) {
     LLVMValueRef fn_val;
     TypeTableEntry *fn_type;
@@ -2967,8 +3046,23 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
     }

     LLVMCallConv llvm_cc = get_llvm_cc(g, fn_type->data.fn.fn_type_id.cc);
-    LLVMValueRef result = ZigLLVMBuildCall(g->builder, fn_val,
-            gen_param_values, (unsigned)gen_param_index, llvm_cc, fn_inline, "");
+    LLVMValueRef result;
+
+    if (instruction->new_stack == nullptr) {
+        result = ZigLLVMBuildCall(g->builder, fn_val,
+                gen_param_values, (unsigned)gen_param_index, llvm_cc, fn_inline, "");
+    } else {
+        LLVMValueRef stacksave_fn_val = get_stacksave_fn_val(g);
+        LLVMValueRef stackrestore_fn_val = get_stackrestore_fn_val(g);
+
+        LLVMValueRef new_stack_addr = get_new_stack_addr(g, ir_llvm_value(g, instruction->new_stack));
+        LLVMValueRef old_stack_ref = LLVMBuildCall(g->builder, stacksave_fn_val, nullptr, 0, "");
+        gen_set_stack_pointer(g, new_stack_addr);
+        result = ZigLLVMBuildCall(g->builder, fn_val,
+                gen_param_values, (unsigned)gen_param_index, llvm_cc, fn_inline, "");
+        LLVMBuildCall(g->builder, stackrestore_fn_val, &old_stack_ref, 1, "");
+    }
+

     for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) {
         FnGenParamInfo *gen_info = &fn_type->data.fn.gen_param_info[param_i];
@@ -6171,6 +6265,7 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2);
     create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX);
     create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX);
+    create_builtin_fn(g, BuiltinFnIdNewStackCall, "newStackCall", SIZE_MAX);
     create_builtin_fn(g, BuiltinFnIdTypeId, "typeId", 1);
     create_builtin_fn(g, BuiltinFnIdShlExact, "shlExact", 2);
     create_builtin_fn(g, BuiltinFnIdShrExact, "shrExact", 2);
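As a plain-Zig illustration (not from the commit) of the address arithmetic in get_new_stack_addr above: the end address of the slice is rounded down to the ABI alignment of usize, and that aligned end address is what gets written into the stack pointer register, which matches a downward-growing stack. The concrete numbers and the use of @alignOf(usize) as a stand-in for get_abi_alignment are assumptions made for this sketch.

const std = @import("std");
const assert = std.debug.assert;

test "aligned end address computation" {
    // Hypothetical values standing in for the slice's ptr and len fields.
    const ptr_addr: usize = 0x1003;
    const len: usize = 1024;
    const align_amt: usize = @alignOf(usize); // stand-in for get_abi_alignment(usize)

    const end_addr = ptr_addr + len;          // one past the end of the buffer
    const align_adj = end_addr % align_amt;   // distance past the previous aligned address
    const aligned_end = end_addr - align_adj; // value written into the stack pointer register

    assert(aligned_end <= end_addr);
    assert(aligned_end % align_amt == 0);
}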

src/ir.cpp

Lines changed: 57 additions & 9 deletions
@@ -1102,7 +1102,8 @@ static IrInstruction *ir_build_union_field_ptr_from(IrBuilder *irb, IrInstructio
 
 static IrInstruction *ir_build_call(IrBuilder *irb, Scope *scope, AstNode *source_node,
     FnTableEntry *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args,
-    bool is_comptime, FnInline fn_inline, bool is_async, IrInstruction *async_allocator)
+    bool is_comptime, FnInline fn_inline, bool is_async, IrInstruction *async_allocator,
+    IrInstruction *new_stack)
 {
     IrInstructionCall *call_instruction = ir_build_instruction<IrInstructionCall>(irb, scope, source_node);
     call_instruction->fn_entry = fn_entry;
@@ -1113,23 +1114,27 @@ static IrInstruction *ir_build_call(IrBuilder *irb, Scope *scope, AstNode *sourc
     call_instruction->arg_count = arg_count;
     call_instruction->is_async = is_async;
     call_instruction->async_allocator = async_allocator;
+    call_instruction->new_stack = new_stack;

     if (fn_ref)
         ir_ref_instruction(fn_ref, irb->current_basic_block);
     for (size_t i = 0; i < arg_count; i += 1)
         ir_ref_instruction(args[i], irb->current_basic_block);
     if (async_allocator)
         ir_ref_instruction(async_allocator, irb->current_basic_block);
+    if (new_stack != nullptr)
+        ir_ref_instruction(new_stack, irb->current_basic_block);

     return &call_instruction->base;
 }

 static IrInstruction *ir_build_call_from(IrBuilder *irb, IrInstruction *old_instruction,
     FnTableEntry *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args,
-    bool is_comptime, FnInline fn_inline, bool is_async, IrInstruction *async_allocator)
+    bool is_comptime, FnInline fn_inline, bool is_async, IrInstruction *async_allocator,
+    IrInstruction *new_stack)
 {
     IrInstruction *new_instruction = ir_build_call(irb, old_instruction->scope,
-        old_instruction->source_node, fn_entry, fn_ref, arg_count, args, is_comptime, fn_inline, is_async, async_allocator);
+        old_instruction->source_node, fn_entry, fn_ref, arg_count, args, is_comptime, fn_inline, is_async, async_allocator, new_stack);
     ir_link_new_instruction(new_instruction, old_instruction);
     return new_instruction;
 }
@@ -4303,7 +4308,37 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
             }
             FnInline fn_inline = (builtin_fn->id == BuiltinFnIdInlineCall) ? FnInlineAlways : FnInlineNever;

-            IrInstruction *call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, fn_inline, false, nullptr);
+            IrInstruction *call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, fn_inline, false, nullptr, nullptr);
+            return ir_lval_wrap(irb, scope, call, lval);
+        }
+        case BuiltinFnIdNewStackCall:
+        {
+            if (node->data.fn_call_expr.params.length == 0) {
+                add_node_error(irb->codegen, node, buf_sprintf("expected at least 1 argument, found 0"));
+                return irb->codegen->invalid_instruction;
+            }
+
+            AstNode *new_stack_node = node->data.fn_call_expr.params.at(0);
+            IrInstruction *new_stack = ir_gen_node(irb, new_stack_node, scope);
+            if (new_stack == irb->codegen->invalid_instruction)
+                return new_stack;
+
+            AstNode *fn_ref_node = node->data.fn_call_expr.params.at(1);
+            IrInstruction *fn_ref = ir_gen_node(irb, fn_ref_node, scope);
+            if (fn_ref == irb->codegen->invalid_instruction)
+                return fn_ref;
+
+            size_t arg_count = node->data.fn_call_expr.params.length - 2;
+
+            IrInstruction **args = allocate<IrInstruction*>(arg_count);
+            for (size_t i = 0; i < arg_count; i += 1) {
+                AstNode *arg_node = node->data.fn_call_expr.params.at(i + 2);
+                args[i] = ir_gen_node(irb, arg_node, scope);
+                if (args[i] == irb->codegen->invalid_instruction)
+                    return args[i];
+            }
+
+            IrInstruction *call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, FnInlineAuto, false, nullptr, new_stack);
             return ir_lval_wrap(irb, scope, call, lval);
         }
         case BuiltinFnIdTypeId:
@@ -4513,7 +4548,7 @@ static IrInstruction *ir_gen_fn_call(IrBuilder *irb, Scope *scope, AstNode *node
         }
     }

-    IrInstruction *fn_call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, FnInlineAuto, is_async, async_allocator);
+    IrInstruction *fn_call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, FnInlineAuto, is_async, async_allocator, nullptr);
     return ir_lval_wrap(irb, scope, fn_call, lval);
 }
 
@@ -6825,7 +6860,7 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
         IrInstruction **args = allocate<IrInstruction *>(arg_count);
         args[0] = implicit_allocator_ptr; // self
         args[1] = mem_slice; // old_mem
-        ir_build_call(irb, scope, node, nullptr, free_fn, arg_count, args, false, FnInlineAuto, false, nullptr);
+        ir_build_call(irb, scope, node, nullptr, free_fn, arg_count, args, false, FnInlineAuto, false, nullptr, nullptr);

         IrBasicBlock *resume_block = ir_create_basic_block(irb, scope, "Resume");
         ir_build_cond_br(irb, scope, node, resume_awaiter, resume_block, irb->exec->coro_suspend_block, const_bool_false);
@@ -11992,7 +12027,7 @@ static IrInstruction *ir_analyze_async_call(IrAnalyze *ira, IrInstructionCall *c
     TypeTableEntry *async_return_type = get_error_union_type(ira->codegen, alloc_fn_error_set_type, promise_type);

     IrInstruction *result = ir_build_call(&ira->new_irb, call_instruction->base.scope, call_instruction->base.source_node,
-            fn_entry, fn_ref, arg_count, casted_args, false, FnInlineAuto, true, async_allocator_inst);
+            fn_entry, fn_ref, arg_count, casted_args, false, FnInlineAuto, true, async_allocator_inst, nullptr);
     result->value.type = async_return_type;
     return result;
 }
@@ -12362,6 +12397,19 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
         return ir_finish_anal(ira, return_type);
     }

+    IrInstruction *casted_new_stack = nullptr;
+    if (call_instruction->new_stack != nullptr) {
+        TypeTableEntry *u8_ptr = get_pointer_to_type(ira->codegen, ira->codegen->builtin_types.entry_u8, false);
+        TypeTableEntry *u8_slice = get_slice_type(ira->codegen, u8_ptr);
+        IrInstruction *new_stack = call_instruction->new_stack->other;
+        if (type_is_invalid(new_stack->value.type))
+            return ira->codegen->builtin_types.entry_invalid;
+
+        casted_new_stack = ir_implicit_cast(ira, new_stack, u8_slice);
+        if (type_is_invalid(casted_new_stack->value.type))
+            return ira->codegen->builtin_types.entry_invalid;
+    }
+
     if (fn_type->data.fn.is_generic) {
         if (!fn_entry) {
             ir_add_error(ira, call_instruction->fn_ref,
@@ -12588,7 +12636,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
         assert(async_allocator_inst == nullptr);
         IrInstruction *new_call_instruction = ir_build_call_from(&ira->new_irb, &call_instruction->base,
                 impl_fn, nullptr, impl_param_count, casted_args, false, fn_inline,
-                call_instruction->is_async, nullptr);
+                call_instruction->is_async, nullptr, casted_new_stack);

         ir_add_alloca(ira, new_call_instruction, return_type);
 
@@ -12679,7 +12727,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
 
 
     IrInstruction *new_call_instruction = ir_build_call_from(&ira->new_irb, &call_instruction->base,
-        fn_entry, fn_ref, call_param_count, casted_args, false, fn_inline, false, nullptr);
+        fn_entry, fn_ref, call_param_count, casted_args, false, fn_inline, false, nullptr, casted_new_stack);

     ir_add_alloca(ira, new_call_instruction, return_type);
     return ir_finish_anal(ira, return_type);
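Not from the commit, but a sketch of the argument convention established by the ir.cpp changes above: the first builtin argument is the stack slice (implicitly cast to []u8 in ir_analyze_fn_call), the second is the callee, and the remaining arguments are forwarded to the call in order (ir_gen_builtin_fn_call). The buffer, callee, and numbers below are made up for illustration.

const std = @import("std");
const assert = std.debug.assert;

var side_stack: [1024]u8 = undefined;

test "arguments after the callee are forwarded in order" {
    // side_stack[0..] is passed as the []u8 stack slice; subtract receives (10, 3)
    // exactly as in a direct call subtract(10, 3).
    assert(@newStackCall(side_stack[0..], subtract, 10, 3) == 7);
}

fn subtract(a: i32, b: i32) i32 {
    return a - b;
}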
