Skip to content

Commit fe216ae

Browse files
vtjnashKristofferC
authored andcommitted
Better handling for Union-type fields, particularly of singletons (#43163)
fix #43123 (cherry picked from commit d44a534)
1 parent 6b1637b commit fe216ae

File tree

4 files changed

+88
-72
lines changed

4 files changed

+88
-72
lines changed

src/cgutils.cpp

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,22 +2147,25 @@ static bool emit_getfield_unknownidx(jl_codectx_t &ctx,
21472147
return false;
21482148
}
21492149

2150-
static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex, jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl)
2151-
{
2152-
Instruction *tindex0 = tbaa_decorate(tbaa_unionselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
2153-
//tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
2154-
// ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
2155-
// ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
2150+
static jl_cgval_t emit_unionload(jl_codectx_t &ctx, Value *addr, Value *ptindex,
2151+
jl_value_t *jfty, size_t fsz, size_t al, MDNode *tbaa, bool mutabl,
2152+
unsigned union_max, MDNode *tbaa_ptindex)
2153+
{
2154+
Instruction *tindex0 = tbaa_decorate(tbaa_ptindex, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
2155+
tindex0->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
2156+
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
2157+
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
21562158
Value *tindex = ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex0);
2157-
if (mutabl) {
2159+
if (fsz > 0 && mutabl) {
21582160
// move value to an immutable stack slot (excluding tindex)
2159-
Type *ET = IntegerType::get(jl_LLVMContext, 8 * al);
2160-
AllocaInst *lv = emit_static_alloca(ctx, ET);
2161-
lv->setOperand(0, ConstantInt::get(T_int32, (fsz + al - 1) / al));
2161+
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (fsz + al - 1) / al);
2162+
AllocaInst *lv = emit_static_alloca(ctx, AT);
2163+
if (al > 1)
2164+
lv->setAlignment(Align(al));
21622165
emit_memcpy(ctx, lv, tbaa, addr, tbaa, fsz, al);
21632166
addr = lv;
21642167
}
2165-
return mark_julia_slot(addr, jfty, tindex, tbaa);
2168+
return mark_julia_slot(fsz > 0 ? addr : nullptr, jfty, tindex, tbaa);
21662169
}
21672170

21682171
// If `nullcheck` is not NULL and a pointer NULL check is necessary
@@ -2236,7 +2239,8 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
22362239
}
22372240
else if (jl_is_uniontype(jfty)) {
22382241
size_t fsz = 0, al = 0;
2239-
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
2242+
int union_max = jl_islayout_inline(jfty, &fsz, &al);
2243+
bool isptr = (union_max == 0);
22402244
assert(!isptr && fsz == jl_field_size(jt, idx) - 1); (void)isptr;
22412245
Value *ptindex;
22422246
if (isboxed) {
@@ -2246,7 +2250,7 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
22462250
else {
22472251
ptindex = emit_struct_gep(ctx, cast<StructType>(lt), staddr, byte_offset + fsz);
22482252
}
2249-
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl);
2253+
return emit_unionload(ctx, addr, ptindex, jfty, fsz, al, tbaa, jt->name->mutabl, union_max, tbaa_unionselbyte);
22502254
}
22512255
assert(jl_is_concrete_type(jfty));
22522256
if (!jt->name->mutabl && !(maybe_null && (jfty == (jl_value_t*)jl_bool_type ||
@@ -3298,7 +3302,8 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
32983302
jl_value_t *jfty = jl_field_type(sty, idx0);
32993303
if (!jl_field_isptr(sty, idx0) && jl_is_uniontype(jfty)) {
33003304
size_t fsz = 0, al = 0;
3301-
bool isptr = !jl_islayout_inline(jfty, &fsz, &al);
3305+
int union_max = jl_islayout_inline(jfty, &fsz, &al);
3306+
bool isptr = (union_max == 0);
33023307
assert(!isptr && fsz == jl_field_size(sty, idx0) - 1); (void)isptr;
33033308
// compute tindex from rhs
33043309
jl_cgval_t rhs_union = convert_julia_type(ctx, rhs, jfty);
@@ -3310,9 +3315,9 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
33103315
BasicBlock *BB = ctx.builder.GetInsertBlock();
33113316
jl_cgval_t oldval = rhs;
33123317
if (!issetfield)
3313-
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
3314-
Value *Success;
3315-
BasicBlock *DoneBB;
3318+
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
3319+
Value *Success = NULL;
3320+
BasicBlock *DoneBB = NULL;
33163321
if (isreplacefield || ismodifyfield) {
33173322
if (ismodifyfield) {
33183323
if (needlock)
@@ -3329,13 +3334,13 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
33293334
emit_typecheck(ctx, rhs, jfty, fname);
33303335
rhs = update_julia_type(ctx, rhs, jfty);
33313336
}
3332-
rhs_union = convert_julia_type(ctx, rhs, jfty);
3337+
rhs_union = convert_julia_type(ctx, rhs, jfty);
33333338
if (rhs_union.typ == jl_bottom_type)
33343339
return jl_cgval_t();
33353340
if (needlock)
33363341
emit_lockstate_value(ctx, strct, true);
33373342
cmp = oldval;
3338-
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true);
3343+
oldval = emit_unionload(ctx, addr, ptindex, jfty, fsz, al, strct.tbaa, true, union_max, tbaa_unionselbyte);
33393344
}
33403345
BasicBlock *XchgBB = BasicBlock::Create(jl_LLVMContext, "xchg", ctx.f);
33413346
DoneBB = BasicBlock::Create(jl_LLVMContext, "done_xchg", ctx.f);

src/codegen.cpp

Lines changed: 45 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2375,25 +2375,18 @@ static Value *emit_box_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1, const
23752375
Value *nullcheck1, Value *nullcheck2)
23762376
{
23772377
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
2378-
assert((arg1.isboxed || arg1.constant) && (arg2.isboxed || arg2.constant) &&
2379-
"Expected unboxed cases to be handled earlier");
2380-
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : arg1.V;
2381-
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : arg2.V;
2382-
varg1 = maybe_decay_tracked(ctx, varg1);
2383-
varg2 = maybe_decay_tracked(ctx, varg2);
2384-
if (cast<PointerType>(varg1->getType())->getAddressSpace() != cast<PointerType>(varg2->getType())->getAddressSpace()) {
2385-
varg1 = decay_derived(ctx, varg1);
2386-
varg2 = decay_derived(ctx, varg2);
2387-
}
2388-
return ctx.builder.CreateICmpEQ(emit_bitcast(ctx, varg1, T_pint8),
2389-
emit_bitcast(ctx, varg2, T_pint8));
2378+
// if we can be certain we won't try to load from the pointer (because
2379+
// we know boxed is trivial), we can skip the separate null checks
2380+
// and just do the ICmpEQ test
2381+
if (!arg1.TIndex && !arg2.TIndex)
2382+
nullcheck1 = nullcheck2 = nullptr;
23902383
}
2391-
23922384
return emit_nullcheck_guard2(ctx, nullcheck1, nullcheck2, [&] {
2393-
Value *varg1 = arg1.constant ? literal_pointer_val(ctx, arg1.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg1).V, T_pjlvalue);
2394-
Value *varg2 = arg2.constant ? literal_pointer_val(ctx, arg2.constant) : maybe_bitcast(ctx, value_to_pointer(ctx, arg2).V, T_pjlvalue);
2395-
varg1 = decay_derived(ctx, varg1);
2396-
varg2 = decay_derived(ctx, varg2);
2385+
Value *varg1 = decay_derived(ctx, boxed(ctx, arg1));
2386+
Value *varg2 = decay_derived(ctx, boxed(ctx, arg2));
2387+
if (jl_pointer_egal(arg1.typ) || jl_pointer_egal(arg2.typ)) {
2388+
return ctx.builder.CreateICmpEQ(varg1, varg2);
2389+
}
23972390
Value *neq = ctx.builder.CreateICmpNE(varg1, varg2);
23982391
return emit_guarded_test(ctx, neq, true, [&] {
23992392
Value *dtarg = emit_typeof_boxed(ctx, arg1);
@@ -2938,28 +2931,28 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
29382931
*ret = ghostValue(ety);
29392932
}
29402933
else if (!isboxed && jl_is_uniontype(ety)) {
2941-
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
2942-
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
2943-
// isbits union selector bytes are stored after a->maxsize
2944-
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
2945-
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
2934+
Value *data = emit_arrayptr(ctx, ary, ary_ex);
29462935
Value *offset = emit_arrayoffset(ctx, ary, nd);
2947-
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
2948-
Value *selidx_m = emit_arraylen(ctx, ary);
2949-
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
2950-
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
2936+
Value *ptindex;
2937+
if (elsz == 0) {
2938+
ptindex = data;
2939+
}
2940+
else {
2941+
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
2942+
data = emit_bitcast(ctx, data, AT->getPointerTo());
2943+
// isbits union selector bytes are stored after a->maxsize
2944+
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
2945+
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
2946+
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
2947+
Value *selidx_m = emit_arraylen(ctx, ary);
2948+
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
2949+
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
2950+
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
2951+
}
29512952
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
29522953
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
29532954
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
2954-
Instruction *tindex = tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateAlignedLoad(T_int8, ptindex, Align(1)));
2955-
tindex->setMetadata(LLVMContext::MD_range, MDNode::get(jl_LLVMContext, {
2956-
ConstantAsMetadata::get(ConstantInt::get(T_int8, 0)),
2957-
ConstantAsMetadata::get(ConstantInt::get(T_int8, union_max)) }));
2958-
AllocaInst *lv = emit_static_alloca(ctx, AT);
2959-
if (al > 1)
2960-
lv->setAlignment(Align(al));
2961-
emit_memcpy(ctx, lv, tbaa_arraybuf, ctx.builder.CreateInBoundsGEP(AT, data, idx), tbaa_arraybuf, elsz, al, false);
2962-
*ret = mark_julia_slot(lv, ety, ctx.builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), tindex), tbaa_arraybuf);
2955+
*ret = emit_unionload(ctx, data, ptindex, ety, elsz, al, tbaa_arraybuf, true, union_max, tbaa_arrayselbyte);
29632956
}
29642957
else {
29652958
MDNode *aliasscope = (f == jl_builtin_const_arrayref) ? ctx.aliasscope : nullptr;
@@ -3045,28 +3038,31 @@ static bool emit_builtin_call(jl_codectx_t &ctx, jl_cgval_t *ret, jl_value_t *f,
30453038
if (!isboxed && jl_is_uniontype(ety)) {
30463039
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * al), (elsz + al - 1) / al);
30473040
Value *data = emit_bitcast(ctx, emit_arrayptr(ctx, ary, ary_ex), AT->getPointerTo());
3041+
Value *offset = emit_arrayoffset(ctx, ary, nd);
30483042
// compute tindex from val
30493043
jl_cgval_t rhs_union = convert_julia_type(ctx, val, ety);
30503044
Value *tindex = compute_tindex_unboxed(ctx, rhs_union, ety);
30513045
tindex = ctx.builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
3052-
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
3053-
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
3054-
Value *offset = emit_arrayoffset(ctx, ary, nd);
3055-
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
3056-
Value *selidx_m = emit_arraylen(ctx, ary);
3057-
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
3058-
Value *ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
3046+
Value *ptindex;
3047+
if (elsz == 0) {
3048+
ptindex = data;
3049+
}
3050+
else {
3051+
Value *ndims = (nd == -1 ? emit_arrayndims(ctx, ary) : ConstantInt::get(T_int16, nd));
3052+
Value *is_vector = ctx.builder.CreateICmpEQ(ndims, ConstantInt::get(T_int16, 1));
3053+
Value *selidx_v = ctx.builder.CreateSub(emit_vectormaxsize(ctx, ary), ctx.builder.CreateZExt(offset, T_size));
3054+
Value *selidx_m = emit_arraylen(ctx, ary);
3055+
Value *selidx = ctx.builder.CreateSelect(is_vector, selidx_v, selidx_m);
3056+
ptindex = ctx.builder.CreateInBoundsGEP(AT, data, selidx);
3057+
data = ctx.builder.CreateInBoundsGEP(AT, data, idx);
3058+
}
30593059
ptindex = emit_bitcast(ctx, ptindex, T_pint8);
30603060
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, offset);
30613061
ptindex = ctx.builder.CreateInBoundsGEP(T_int8, ptindex, idx);
30623062
tbaa_decorate(tbaa_arrayselbyte, ctx.builder.CreateStore(tindex, ptindex));
3063-
if (jl_is_datatype(val.typ) && jl_datatype_size(val.typ) == 0) {
3064-
// no-op
3065-
}
3066-
else {
3067-
// copy data
3068-
Value *addr = ctx.builder.CreateInBoundsGEP(AT, data, idx);
3069-
emit_unionmove(ctx, addr, tbaa_arraybuf, val, nullptr);
3063+
if (elsz > 0 && (!jl_is_datatype(val.typ) || jl_datatype_size(val.typ) > 0)) {
3064+
// copy data (if any)
3065+
emit_unionmove(ctx, data, tbaa_arraybuf, val, nullptr);
30703066
}
30713067
}
30723068
else {

src/rtutils.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,12 +1001,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
10011001
n += jl_printf(out, ")}[");
10021002
size_t j, tlen = jl_array_len(v);
10031003
jl_array_t *av = (jl_array_t*)v;
1004-
jl_datatype_t *el_type = (jl_datatype_t*)jl_tparam0(vt);
1004+
jl_value_t *el_type = jl_tparam0(vt);
1005+
char *typetagdata = (!av->flags.ptrarray && jl_is_uniontype(el_type)) ? jl_array_typetagdata(av) : NULL;
10051006
int nlsep = 0;
10061007
if (av->flags.ptrarray) {
10071008
// print arrays with newlines, unless the elements are probably small
10081009
for (j = 0; j < tlen; j++) {
1009-
jl_value_t *p = jl_array_ptr_ref(av, j);
1010+
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
1011+
jl_value_t *p = *ptr;
10101012
if (p != NULL && (uintptr_t)p >= 4096U) {
10111013
jl_value_t *p_ty = jl_typeof(p);
10121014
if ((uintptr_t)p_ty >= 4096U) {
@@ -1022,11 +1024,14 @@ static size_t jl_static_show_x_(JL_STREAM *out, jl_value_t *v, jl_datatype_t *vt
10221024
n += jl_printf(out, "\n ");
10231025
for (j = 0; j < tlen; j++) {
10241026
if (av->flags.ptrarray) {
1025-
n += jl_static_show_x(out, jl_array_ptr_ref(v, j), depth);
1027+
jl_value_t **ptr = ((jl_value_t**)av->data) + j;
1028+
n += jl_static_show_x(out, *ptr, depth);
10261029
}
10271030
else {
10281031
char *ptr = ((char*)av->data) + j * av->elsize;
1029-
n += jl_static_show_x_(out, (jl_value_t*)ptr, el_type, depth);
1032+
n += jl_static_show_x_(out, (jl_value_t*)ptr,
1033+
typetagdata ? (jl_datatype_t*)jl_nth_union_component(el_type, typetagdata[j]) : (jl_datatype_t*)el_type,
1034+
depth);
10301035
}
10311036
if (j != tlen - 1)
10321037
n += jl_printf(out, nlsep ? ",\n " : ", ");

test/compiler/codegen.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,3 +658,13 @@ function f42645()
658658
res
659659
end
660660
@test ((f42645()::B42645).y::A42645{Int}).x
661+
662+
# issue #43123
663+
@noinline cmp43123(a::Some, b::Some) = something(a) === something(b)
664+
@noinline cmp43123(a, b) = a[] === b[]
665+
@test cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(+))
666+
@test !cmp43123(Some{Function}(+), Some{Union{typeof(+), typeof(-)}}(-))
667+
@test cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(+))
668+
@test !cmp43123(Ref{Function}(+), Ref{Union{typeof(+), typeof(-)}}(-))
669+
@test cmp43123(Function[+], Union{typeof(+), typeof(-)}[+])
670+
@test !cmp43123(Function[+], Union{typeof(+), typeof(-)}[-])

0 commit comments

Comments
 (0)