Skip to content

Commit dd57605

Browse files
committed
optimizer: fix #42754, inline union-split const-prop'ed sources
This commit complements #39754 and #39305: implements a logic to use constant-prop'ed results for inlining at union-split callsite. Currently it works only for cases when constant-prop' succeeded for all (union-split) signatures. > example ```julia julia> mutable struct X # NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types a::Union{Nothing, Int} b::Symbol end; julia> code_typed((X, Union{Nothing,Int})) do x, a # this `setproperty` call would be union-split and constant-prop will happen for # each signature: inlining would fail if we don't use constant-prop'ed source # since the approximated inlining cost of `convert(fieldtype(X, sym), a)` would # end up very high if we don't propagate `sym::Const(:a)` x.a = a x end |> only |> first ``` > before this commit ```julia 1 ─ %1 = Base.setproperty!::typeof(setproperty!) │ %2 = (isa)(a, Nothing)::Bool └── goto #3 if not %2 2 ─ %4 = π (a, Nothing) │ invoke %1(_2::X, 🅰️:Symbol, %4::Nothing)::Any └── goto #6 3 ─ %7 = (isa)(a, Int64)::Bool └── goto #5 if not %7 4 ─ %9 = π (a, Int64) │ invoke %1(_2::X, 🅰️:Symbol, %9::Int64)::Any └── goto #6 5 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 6 ┄ return x ) ``` > after this commit ```julia CodeInfo( 1 ─ %1 = (isa)(a, Nothing)::Bool └── goto #3 if not %1 2 ─ Base.setfield!(x, :a, nothing)::Nothing └── goto #6 3 ─ %5 = (isa)(a, Int64)::Bool └── goto #5 if not %5 4 ─ %7 = π (a, Int64) │ Base.setfield!(x, :a, %7)::Int64 └── goto #6 5 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 6 ┄ return x ) ```
1 parent 59aa3ed commit dd57605

File tree

3 files changed

+92
-66
lines changed

3 files changed

+92
-66
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
685685
end
686686

687687
function const_prop_function_heuristic(
688-
interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo,
688+
_::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo,
689689
nargs::Int, all_overridden::Bool, _::InferenceState)
690690
if nargs > 1
691691
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
@@ -704,9 +704,9 @@ function const_prop_function_heuristic(
704704
end
705705
end
706706
if !all_overridden && (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) ||
707-
istopfunction(f, :(==)) || istopfunction(f, :!=) ||
708-
istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) ||
709-
istopfunction(f, :<<) || istopfunction(f, :>>))
707+
istopfunction(f, :(==)) || istopfunction(f, :!=) ||
708+
istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) ||
709+
istopfunction(f, :<<) || istopfunction(f, :>>))
710710
# it is almost useless to inline the op when all the same type,
711711
# but highly worthwhile to inline promote of a constant
712712
length(argtypes) > 2 || return false

base/compiler/ssair/inlining.jl

Lines changed: 61 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ end
3535
struct DelayedInliningSpec
3636
match::Union{MethodMatch, InferenceResult}
3737
atypes::Vector{Any}
38-
stmttype::Any
3938
end
4039

4140
struct InliningTodo
@@ -44,11 +43,11 @@ struct InliningTodo
4443
spec::Union{ResolvedInliningSpec, DelayedInliningSpec}
4544
end
4645

47-
InliningTodo(mi::MethodInstance, match::MethodMatch,
48-
atypes::Vector{Any}, @nospecialize(stmttype)) = InliningTodo(mi, DelayedInliningSpec(match, atypes, stmttype))
46+
InliningTodo(mi::MethodInstance, match::MethodMatch, atypes::Vector{Any}) =
47+
InliningTodo(mi, DelayedInliningSpec(match, atypes))
4948

50-
InliningTodo(result::InferenceResult, atypes::Vector{Any}, @nospecialize(stmttype)) =
51-
InliningTodo(result.linfo, DelayedInliningSpec(result, atypes, stmttype))
49+
InliningTodo(result::InferenceResult, atypes::Vector{Any}) =
50+
InliningTodo(result.linfo, DelayedInliningSpec(result, atypes))
5251

5352
struct ConstantCase
5453
val::Any
@@ -677,7 +676,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
677676
handled = false
678677
if isa(info, ConstCallInfo)
679678
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
680-
ir, state1.id, new_stmt, info, new_sig,call.rt, istate, flag, false, todo)
679+
ir, state1.id, new_stmt, info, new_sig, istate, flag, false, todo)
681680
handled = true
682681
else
683682
info = info.call
@@ -687,8 +686,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
687686
info = isa(info, MethodMatchInfo) ?
688687
MethodMatchInfo[info] : info.matches
689688
# See if we can inline this call to `iterate`
690-
analyze_single_call!(ir, todo, state1.id, new_stmt,
691-
new_sig, call.rt, info, istate, flag)
689+
analyze_single_call!(
690+
ir, todo, state1.id, new_stmt,
691+
new_sig, info, istate, flag)
692692
end
693693
if i != length(thisarginfo.each)
694694
valT = getfield_tfunc(call.rt, Const(1))
@@ -708,11 +708,13 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
708708
return new_argexprs, new_atypes
709709
end
710710

711-
function rewrite_invoke_exprargs!(argexprs::Vector{Any})
711+
function rewrite_invoke_exprargs!(expr::Expr)
712+
argexprs = expr.args
712713
argexpr0 = argexprs[2]
713714
argexprs = argexprs[4:end]
714715
pushfirst!(argexprs, argexpr0)
715-
return argexprs
716+
expr.args = argexprs
717+
return expr
716718
end
717719

718720
function compileable_specialization(et::Union{EdgeTracker, Nothing}, match::MethodMatch)
@@ -791,7 +793,7 @@ function validate_sparams(sparams::SimpleVector)
791793
end
792794

793795
function analyze_method!(match::MethodMatch, atypes::Vector{Any},
794-
state::InliningState, @nospecialize(stmttyp), flag::UInt8)
796+
state::InliningState, flag::UInt8)
795797
method = match.method
796798
methsig = method.sig
797799

@@ -821,7 +823,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
821823
return compileable_specialization(et, match)
822824
end
823825

824-
todo = InliningTodo(mi, match, atypes, stmttyp)
826+
todo = InliningTodo(mi, match, atypes)
825827
# If we don't have caches here, delay resolving this MethodInstance
826828
# until the batch inlining step (or an external post-processing pass)
827829
state.mi_cache === nothing && return todo
@@ -846,17 +848,13 @@ function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(cas
846848
if isa(case, ConstantCase)
847849
ir[SSAValue(idx)] = case.val
848850
elseif isa(case, MethodInstance)
849-
if isinvoke
850-
stmt.args = rewrite_invoke_exprargs!(stmt.args)
851-
end
851+
isinvoke && rewrite_invoke_exprargs!(stmt)
852852
stmt.head = :invoke
853853
pushfirst!(stmt.args, case)
854854
elseif case === nothing
855855
# Do, well, nothing
856856
else
857-
if isinvoke
858-
stmt.args = rewrite_invoke_exprargs!(stmt.args)
859-
end
857+
isinvoke && rewrite_invoke_exprargs!(stmt)
860858
push!(todo, idx=>(case::InliningTodo))
861859
end
862860
nothing
@@ -1005,7 +1003,6 @@ is_builtin(s::Signature) =
10051003
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo,
10061004
state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8)
10071005
stmt = ir.stmts[idx][:inst]
1008-
calltype = ir.stmts[idx][:type]
10091006

10101007
if !match.fully_covers
10111008
# TODO: We could union split out the signature check and continue on
@@ -1018,7 +1015,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result):
10181015
pushfirst!(atypes, atype0)
10191016

10201017
if isa(result, InferenceResult) && !is_stmt_noinline(flag)
1021-
(; mi) = item = InliningTodo(result, atypes, calltype)
1018+
(; mi) = item = InliningTodo(result, atypes)
10221019
validate_sparams(mi.sparam_vals) || return nothing
10231020
if argtypes_to_type(atypes) <: mi.def.sig
10241021
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
@@ -1027,7 +1024,7 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result):
10271024
end
10281025
end
10291026

1030-
result = analyze_method!(match, atypes, state, calltype, flag)
1027+
result = analyze_method!(match, atypes, state, flag)
10311028
handle_single_case!(ir, stmt, idx, result, true, todo)
10321029
return nothing
10331030
end
@@ -1136,13 +1133,12 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
11361133
return sig
11371134
end
11381135

1139-
function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
1140-
sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo},
1141-
state::InliningState, flag::UInt8)
1136+
function analyze_single_call!(
1137+
ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
1138+
sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8)
11421139
cases = Pair{Any, Any}[]
1143-
signature_union = Union{}
1144-
only_method = nothing # keep track of whether there is one matching method
1145-
too_many = false
1140+
local signature_union = Bottom
1141+
local only_method = nothing # keep track of whether there is one matching method
11461142
local meth
11471143
local fully_covered = true
11481144
for i in 1:length(infos)
@@ -1151,8 +1147,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11511147
if meth.ambig
11521148
# Too many applicable methods
11531149
# Or there is a (partial?) ambiguity
1154-
too_many = true
1155-
break
1150+
return
11561151
elseif length(meth) == 0
11571152
# No applicable methods; try next union split
11581153
continue
@@ -1172,7 +1167,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11721167
fully_covered = false
11731168
continue
11741169
end
1175-
case = analyze_method!(match, sig.atypes, state, calltype, flag)
1170+
case = analyze_method!(match, sig.atypes, state, flag)
11761171
if case === nothing
11771172
fully_covered = false
11781173
continue
@@ -1183,8 +1178,6 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11831178
end
11841179
end
11851180

1186-
too_many && return
1187-
11881181
signature_fully_covered = sig.atype <: signature_union
11891182
# If we're fully covered and there's only one applicable method,
11901183
# we inline, even if the signature is not a dispatch tuple
@@ -1199,7 +1192,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11991192
match = meth[1]
12001193
end
12011194
fully_covered = true
1202-
case = analyze_method!(match, sig.atypes, state, calltype, flag)
1195+
case = analyze_method!(match, sig.atypes, state, flag)
12031196
case === nothing && return
12041197
push!(cases, Pair{Any,Any}(match.spec_types, case))
12051198
end
@@ -1219,34 +1212,41 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
12191212
return nothing
12201213
end
12211214

1215+
# try to create `InliningTodo`s using constant-prop'ed results
1216+
# currently it works only when constant-prop' succeeded for all (union-split) signatures
1217+
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
12221218
function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
1223-
info::ConstCallInfo, sig::Signature, @nospecialize(calltype),
1219+
info::ConstCallInfo, sig::Signature,
12241220
state::InliningState, flag::UInt8,
12251221
isinvoke::Bool, todo::Vector{Pair{Int, Any}})
1226-
# when multiple matches are found, bail out and later inliner will union-split this signature
1227-
# TODO effectively use multiple constant analysis results here
1228-
length(info.results) == 1 || return false
1229-
result = info.results[1]
1230-
isa(result, InferenceResult) || return false
1231-
1232-
(; mi) = item = InliningTodo(result, sig.atypes, calltype)
1233-
validate_sparams(mi.sparam_vals) || return true
1234-
mthd_sig = mi.def.sig
1235-
mistypes = mi.specTypes
1236-
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1237-
if sig.atype <: mthd_sig
1238-
handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
1239-
return true
1240-
else
1241-
item === nothing && return true
1242-
# Union split out the error case
1243-
item = UnionSplit(false, sig.atype, Pair{Any, Any}[mistypes => item])
1244-
if isinvoke
1245-
stmt.args = rewrite_invoke_exprargs!(stmt.args)
1222+
sigtype = sig.atype
1223+
cases = Pair{Any, Any}[] # TODO avoid this allocation for single cases ?
1224+
local fully_covered = true
1225+
local signature_union = Bottom
1226+
for result in info.results
1227+
isa(result, InferenceResult) || return false
1228+
(; mi) = item = InliningTodo(result, sig.atypes)
1229+
if !validate_sparams(mi.sparam_vals)
1230+
fully_covered = false
1231+
continue
1232+
end
1233+
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1234+
if item === nothing
1235+
fully_covered = false
1236+
continue
12461237
end
1238+
push!(cases, Pair{Any,Any}(mi.specTypes, item))
1239+
signature_union = Union{signature_union, mi.def.sig}
1240+
end
1241+
fully_covered &= sigtype <: signature_union
1242+
if fully_covered && length(cases) == 1
1243+
handle_single_case!(ir, stmt, idx, cases[1][2], isinvoke, todo)
1244+
elseif length(cases) > 0
1245+
item = UnionSplit(fully_covered, sigtype, cases)
1246+
isinvoke && rewrite_invoke_exprargs!(stmt)
12471247
push!(todo, idx=>item)
1248-
return true
12491248
end
1249+
return true
12501250
end
12511251

12521252
function assemble_inline_todo!(ir::IRCode, state::InliningState)
@@ -1258,11 +1258,11 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
12581258
sig === nothing && continue
12591259

12601260
stmt = ir.stmts[idx][:inst]
1261-
calltype = ir.stmts[idx][:type]
12621261
info = ir.stmts[idx][:info]
12631262

12641263
# Check whether this call was @pure and evaluates to a constant
12651264
if info isa MethodResultPure
1265+
calltype = ir.stmts[idx][:type]
12661266
if calltype isa Const && is_inlineable_constant(calltype.val)
12671267
ir.stmts[idx][:inst] = quoted(calltype.val)
12681268
continue
@@ -1278,20 +1278,19 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
12781278
continue
12791279
end
12801280

1281-
# If inference arrived at this result by using constant propagation,
1282-
# it'll have performed a specialized analysis for just this case. Use its
1283-
# result.
1281+
# if inference arrived here with constant-prop'ed result(s),
1282+
# we can perform a specialized analysis for just this case
12841283
if isa(info, ConstCallInfo)
12851284
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
1286-
ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo)
1285+
ir, idx, stmt, info, sig, state, flag, sig.f === Core.invoke, todo)
12871286
continue
12881287
else
12891288
info = info.call
12901289
end
12911290
end
12921291

12931292
if isa(info, OpaqueClosureCallInfo)
1294-
result = analyze_method!(info.match, sig.atypes, state, calltype, flag)
1293+
result = analyze_method!(info.match, sig.atypes, state, flag)
12951294
handle_single_case!(ir, stmt, idx, result, false, todo)
12961295
continue
12971296
end
@@ -1313,7 +1312,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13131312
continue
13141313
end
13151314

1316-
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag)
1315+
analyze_single_call!(ir, todo, idx, stmt, sig, infos, state, flag)
13171316
end
13181317
todo
13191318
end

test/compiler/inline.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,3 +680,30 @@ let f(x) = (x...,)
680680
# the the original apply call is not union-split, but the inserted `iterate` call is.
681681
@test code_typed(f, Tuple{Union{Int64, CartesianIndex{1}, CartesianIndex{3}}})[1][2] == Tuple{Int64}
682682
end
683+
684+
let # https://github.com/JuliaLang/julia/issues/42754
685+
# inline union-split constant-prop'ed sources
686+
code = @eval Module() begin
687+
mutable struct X
688+
# NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types
689+
a::Union{Nothing, Int}
690+
b::Symbol
691+
end
692+
$code_typed1((X, Union{Nothing,Int})) do x, a
693+
# this `setproperty` call would be union-split and constant-prop will happen for
694+
# each signature: inlining would fail if we don't use constant-prop'ed source
695+
# since the approximate inlining cost of `convert(fieldtype(X, sym), a)` would
696+
# end up very high if we don't propagate `sym::Const(:a)`
697+
x.a = a
698+
x
699+
end
700+
end
701+
@test all(code) do @nospecialize(x)
702+
isinvoke(x, :setproperty!) && return false
703+
if Meta.isexpr(x, :call)
704+
f = x.args[1]
705+
isa(f, GlobalRef) && f.name === :setproperty! && return false
706+
end
707+
return true
708+
end
709+
end

0 commit comments

Comments
 (0)