Skip to content

Better handling for Union-type fields, particularly of singletons #43163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2146,22 +2146,25 @@ static bool emit_getfield_unknownidx(jl_codectx_t &ctx,
return false;
}

static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex, jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl)
{
Instruction *tindex0 = tbaa_decorate(tbaa_unionselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
//tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
// ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
// ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex,
jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl,
unsigned union_max, MDNode *tbaa_ptindex)
{
Instruction *tindex0 = tbaa_decorate(tbaa_ptindex, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex0);
if (mutabl) {
if (fsz > 0 && mutabl) {
// move value to an immutable stack slot (excluding tindex)
Type *ET = IntegerType::get(jl_LLVMContext, 8 * al);
AllocaInst *lv = emit_static_alloca(ctx, ET);
lv->setOperand(0, ConstantInt::get(T_int32, (fsz + al - 1) / al));
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (fsz + al - 1) / al);
AllocaInst *lv = emit_static_alloca(ctx, AT);
if (al > 1)
lv->setAlignment(Align(al));
emit_memcpy(ctx, lv, tbaa, addr, tbaa, fsz, al);
addr = lv;
}
return mark_julia_slot(addr, jfty, tindex, tbaa);
return mark_julia_slot(fsz > 0 ? addr : nullptr, jfty, tindex, tbaa);
}

// If `nullcheck` is not NULL and a pointer NULL check is necessary
Expand Down Expand Up @@ -2235,7 +2238,8 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
}
else if (jl_is_uniontype(jfty)) {
size_t fsz = 0, al = 0;
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
int union_max = jl_islayout_inline(jfty, &fsz, &al);
bool isptr = (union_max == 0);
assert(!isptr && fsz == jl_field_size(jt, idx) - 1); (void)isptr;
Value *ptindex;
if (isboxed) {
Expand All @@ -2245,7 +2249,7 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
else {
ptindex = emit_struct_gep(ctx, cast<StructType>(lt), staddr, byte_offset + fsz);
}
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl);
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl, union_max, tbaa_unionselbyte);
}
assert(jl_is_concrete_type(jfty));
if (!jt->name->mutabl && !(maybe_null && (jfty == (jl_value_t*)jl_bool_type ||
Expand Down Expand Up @@ -3306,7 +3310,8 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
jl_value_t *jfty = jl_field_type(sty, idx0);
if (!jl_field_isptr(sty, idx0) && jl_is_uniontype(jfty)) {
size_t fsz = 0, al = 0;
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
int union_max = jl_islayout_inline(jfty, &fsz, &al);
bool isptr = (union_max == 0);
assert(!isptr && fsz == jl_field_size(sty, idx0) - 1); (void)isptr;
// compute tindex from rhs
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
Expand All @@ -3323,7 +3328,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
}
jl_cgval_t oldval = rhs;
if (!issetfield)
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
Value *Success = NULL;
BasicBlock *DoneBB = NULL;
if (isreplacefield || ismodifyfield) {
Expand All @@ -3342,13 +3347,13 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
emit_typecheck(ctx, rhs, jfty, fname);
rhs = update_julia_type(ctx, rhs, jfty);
}
rhs_union = convert_julia_type(ctx, rhs, jfty);
rhs_union = convert_julia_type(ctx, rhs, jfty);
if (rhs_union.typ == jl_bottom_type)
return jl_cgval_t();
if (needlock)
emit_lockstate_value(ctx, strct, true);
cmp = oldval;
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
}
BasicBlock *XchgBB = BasicBlock::Create(jl_LLVMContext, "xchg", ctx.f);
DoneBB = BasicBlock::Create(jl_LLVMContext, "done_xchg", ctx.f);
Expand Down
104 changes: 46 additions & 58 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2178,34 +2178,19 @@ static jl_cgval_t emit_globalref(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *
static Value *emit_box_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const jl_cgval_t &arg2,
Value *nullcheck1, Value *nullcheck2)
{
// If either sides is boxed or can be trivially boxed,
// we'll prefer to do a pointer check.
// At this point, we know that at least one of the arguments isn't a constant
// so a runtime content check will involve at least one load from the
// pointer (and likely a type check)
// so a pointer comparison should be no worse than that even in imaging mode
// when the constant pointer has to be loaded.
// Note that we ignore nullcheck, since in the case where it may be set, we
// also knew the types of both fields must be the same so there cannot be
// any unboxed values on either side.
if ((!arg1.TIndex && jl_pointer_egal(arg1.typ)) || (!arg2.TIndex && jl_pointer_egal(arg2.typ))) {
// n.b. Vboxed may be incomplete if Tindex is set (missing singletons)
// and Vboxed == isboxed || Tindex
if ((arg1.Vboxed || arg1.constant) && (arg2.Vboxed || arg2.constant)) {
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : maybe_bitcast(ctx, arg1.Vboxed, T_pjlvalue);
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : maybe_bitcast(ctx, arg2.Vboxed, T_pjlvalue);
return ctx.builder.CreateICmpEQ(decay_derived(ctx, varg1), decay_derived(ctx, varg2));
}
return ConstantInt::get(T_int1, 0); // seems probably unreachable?
// (since intersection of rt1 and rt2 is non-empty here, so we should have
// a value in this intersection, but perhaps intersection might have failed)
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
// if we can be certain we won't try to load from the pointer (because
// we know boxed is trivial), we can skip the separate null checks
// and just do the ICmpEQ test
if (!arg1.TIndex && !arg2.TIndex)
nullcheck1 = nullcheck2 = nullptr;
}

return emit_nullcheck_guard2(ctx, nullcheck1, nullcheck2, [&] {
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg1).V, T_pjlvalue);
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg2).V, T_pjlvalue);
varg1 = decay_derived(ctx, varg1);
varg2 = decay_derived(ctx, varg2);
Value *varg1 = decay_derived(ctx, boxed(ctx, arg1));
Value *varg2 = decay_derived(ctx, boxed(ctx, arg2));
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
return ctx.builder.CreateICmpEQ(varg1, varg2);
}
Value *neq = ctx.builder.CreateICmpNE(varg1, varg2);
return emit_guarded_test(ctx, neq, true, [&] {
Value *dtarg = emit_typeof_boxed(ctx, arg1);
Expand Down Expand Up @@ -2731,28 +2716,28 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
*ret = ghostValue(ety);
}
else if (!isboxed && jl_is_uniontype(ety)) {
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
// isbits union selector bytes are stored after a->maxsize
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *data = emit_arrayptr(ctx, ary, ary_ex);
Value *offset = emit_arrayoffset(ctx, ary, nd);
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
Value *ptindex;
if (elsz == 0) {
ptindex = data;
}
else {
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
data = emit_bitcast(ctx, data, AT->getPointerTo());
// isbits union selector bytes are stored after a->maxsize
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
}
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
Instruction *tindex = tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
tindex->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
AllocaInst *lv = emit_static_alloca(ctx, AT);
if (al > 1)
lv->setAlignment(Align(al));
emit_memcpy(ctx, lv, tbaa_arraybuf, ctx.builder.CreateInBoundsGEP(AT, data, idx), tbaa_arraybuf, elsz, al, false);
*ret = mark_julia_slot(lv, ety, ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex), tbaa_arraybuf);
*ret = emit_unionload(ctx, data, ptindex, ety, elsz, al, tbaa_arraybuf, true, union_max, tbaa_arrayselbyte);
}
else {
MDNode *aliasscope = (f == jl_builtin_const_arrayref) ? ctx.aliasscope : nullptr;
Expand Down Expand Up @@ -2838,28 +2823,31 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
if (!isboxed && jl_is_uniontype(ety)) {
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
Value *offset = emit_arrayoffset(ctx, ary, nd);
// compute tindex from val
jl_cgval_t rhs_union = convert_julia_type(ctx, val, ety);
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, ety);
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *offset = emit_arrayoffset(ctx, ary, nd);
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
Value *ptindex;
if (elsz == 0) {
ptindex = data;
}
else {
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
Value *selidx_m = emit_arraylen(ctx, ary);
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
}
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateStore(tindex, ptindex));
if (jl_is_datatype(val.typ) && jl_datatype_size(val.typ) == 0) {
// no-op
}
else {
// copy data
Value *addr = ctx.builder.CreateInBoundsGEP(AT, data, idx);
emit_unionmove(ctx, addr, tbaa_arraybuf, val, nullptr);
if (elsz > 0 && (!jl_is_datatype(val.typ) || jl_datatype_size(val.typ) > 0)) {
// copy data (if any)
emit_unionmove(ctx, data, tbaa_arraybuf, val, nullptr);
}
}
else {
Expand Down
13 changes: 9 additions & 4 deletions src/rtutils.c
Original file line number Diff line number Diff line change
Expand Up @@ -1014,12 +1014,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
n += jl_printf(out, ")}[");
size_t j, tlen = jl_array_len(v);
jl_array_t *av = (jl_array_t*)v;
jl_datatype_t *el_type = (jl_datatype_t*)jl_tparam0(vt);
jl_value_t *el_type = jl_tparam0(vt);
char *typetagdata = (!av->flags.ptrarray && jl_is_uniontype(el_type)) ? jl_array_typetagdata(av) : NULL;
int nlsep = 0;
if (av->flags.ptrarray) {
// print arrays with newlines, unless the elements are probably small
for (j = 0; j < tlen; j++) {
jl_value_t *p = jl_array_ptr_ref(av, j);
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
jl_value_t *p = *ptr;
if (p != NULL && (uintptr_t)p >= 4096U) {
jl_value_t *p_ty = jl_typeof(p);
if ((uintptr_t)p_ty >= 4096U) {
Expand All @@ -1035,11 +1037,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
n += jl_printf(out, "\n ");
for (j = 0; j < tlen; j++) {
if (av->flags.ptrarray) {
n += jl_static_show_x(out, jl_array_ptr_ref(v, j), depth);
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
n += jl_static_show_x(out, *ptr, depth);
}
else {
char *ptr = ((char*)av->data) + j * av->elsize;
n += jl_static_show_x_(out, (jl_value_t*)ptr, el_type, depth);
n += jl_static_show_x_(out, (jl_value_t*)ptr,
typetagdata ? (jl_datatype_t*)jl_nth_union_component(el_type, typetagdata[j]) : (jl_datatype_t*)el_type,
depth);
}
if (j != tlen - 1)
n += jl_printf(out, nlsep ? ",\n " : ", ");
Expand Down
10 changes: 10 additions & 0 deletions test/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,13 @@ function f42645()
res
end
@test ((f42645()::B42645).y::A42645{Int}).x

# issue #43123
@noinline cmp43123(a::Some, b::Some) = something(a) === something(b)
@noinline cmp43123(a, b) = a[] === b[]
@test cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(+))
@test !cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(-))
@test cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(+))
@test !cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(-))
@test cmp43123(Function[+], Union{typeof(+), typeof(-)}[+])
@test !cmp43123(Function[+], Union{typeof(+), typeof(-)}[-])