Skip to content

Commit 1ef49c8

Browse files
authored
inference: parameterize some of hard-coded inference logic (#39439)
This commit parameterizes some of hard-coded inference logic: - to bail out from inference when a lattice element can't be refined or a current inference frame is proven to throw or to be dead - to add call backedges when the call return type won't be refined Those `AbstractInterpreter`s used for code optimization (including `NativeInterpreter`) usually just want the methods defined for `AbstractInterpreter`, but some other `AbstractInterpreter` may want other implementations and heuristics to control inference process. For example, [`JETInterpreter`](https://github.com/aviatesk/JET.jl) is used for code analysis and wants to add call backedges even when a call return type is `Any`.
1 parent 5d7e13f commit 1ef49c8

File tree

2 files changed

+53
-21
lines changed

2 files changed

+53
-21
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
9797
napplicable = length(applicable)
9898
rettype = Bottom
9999
edgecycle = false
100-
edges = Any[]
100+
edges = MethodInstance[]
101101
nonbot = 0 # the index of the only non-Bottom inference result if > 0
102102
seen = 0 # number of signatures actually inferred
103-
istoplevel = sv.linfo.def isa Module
104103
multiple_matches = napplicable > 1
105104

106105
if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
@@ -115,7 +114,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
115114
match = applicable[i]::MethodMatch
116115
method = match.method
117116
sig = match.spec_types
118-
if istoplevel && !isdispatchtuple(sig)
117+
if bail_out_toplevel_call(interp, sig, sv)
119118
# only infer concrete call sites in top-level expressions
120119
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
121120
rettype = Any
@@ -135,7 +134,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
135134
end
136135
edgecycle |= edgecycle1::Bool
137136
this_rt = tmerge(this_rt, rt)
138-
this_rt === Any && break
137+
if bail_out_call(interp, this_rt, sv)
138+
break
139+
end
139140
end
140141
else
141142
this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv)
@@ -153,7 +154,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
153154
end
154155
seen += 1
155156
rettype = tmerge(rettype, this_rt)
156-
rettype === Any && break
157+
if bail_out_call(interp, rettype, sv)
158+
break
159+
end
157160
end
158161
# try constant propagation if only 1 method is inferred to non-Bottom
159162
# this is in preparation for inlining, or improving the return result
@@ -179,18 +182,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
179182
# and avoid keeping track of a more complex result type.
180183
rettype = Any
181184
end
182-
if !(rettype === Any) # adding a new method couldn't refine (widen) this type
183-
for edge in edges
184-
add_backedge!(edge::MethodInstance, sv)
185-
end
186-
for (thisfullmatch, mt) in zip(fullmatch, mts)
187-
if !thisfullmatch
188-
# also need an edge to the method table in case something gets
189-
# added that did not intersect with any existing method
190-
add_mt_backedge!(mt, atype, sv)
191-
end
192-
end
193-
end
185+
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
194186
#print("=> ", rettype, "\n")
195187
if rettype isa LimitedAccuracy
196188
union!(sv.pclimitations, rettype.causes)
@@ -205,6 +197,27 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
205197
return CallMeta(rettype, info)
206198
end
207199

200+
function add_call_backedges!(interp::AbstractInterpreter,
201+
@nospecialize(rettype),
202+
edges::Vector{MethodInstance},
203+
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
204+
sv::InferenceState)
205+
if rettype === Any
206+
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine
207+
# (widen) this type
208+
return
209+
end
210+
for edge in edges
211+
add_backedge!(edge, sv)
212+
end
213+
for (thisfullmatch, mt) in zip(fullmatch, mts)
214+
if !thisfullmatch
215+
# also need an edge to the method table in case something gets
216+
# added that did not intersect with any existing method
217+
add_mt_backedge!(mt, atype, sv)
218+
end
219+
end
220+
end
208221

209222
function const_prop_profitable(@nospecialize(arg))
210223
# have new information from argtypes that wasn't available from the signature
@@ -746,7 +759,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
746759
call = abstract_call(interp, nothing, ct, sv, max_methods)
747760
push!(retinfos, ApplyCallInfo(call.info, arginfo))
748761
res = tmerge(res, call.rt)
749-
if res === Any
762+
if bail_out_apply(interp, res, sv)
750763
# No point carrying forward the info, we're not gonna inline it anyway
751764
retinfo = nothing
752765
break
@@ -1171,7 +1184,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
11711184
argtypes = Vector{Any}(undef, n)
11721185
@inbounds for i = 1:n
11731186
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
1174-
if ai === Bottom
1187+
if bail_out_statement(interp, ai, sv)
11751188
return Bottom
11761189
end
11771190
argtypes[i] = ai
@@ -1349,6 +1362,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13491362
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
13501363
if condt === Bottom
13511364
empty!(frame.pclimitations)
1365+
end
1366+
if bail_out_local(interp, condt, frame)
13521367
break
13531368
end
13541369
condval = maybe_extract_const_bool(condt)
@@ -1440,7 +1455,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14401455
else
14411456
if hd === :(=)
14421457
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
1443-
t === Bottom && break
1458+
if bail_out_local(interp, t, frame)
1459+
break
1460+
end
14441461
frame.src.ssavaluetypes[pc] = t
14451462
lhs = stmt.args[1]
14461463
if isa(lhs, Slot)
@@ -1455,7 +1472,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14551472
# these do not generate code
14561473
else
14571474
t = abstract_eval_statement(interp, stmt, changes, frame)
1458-
t === Bottom && break
1475+
if bail_out_local(interp, t, frame)
1476+
break
1477+
end
14591478
if !isempty(frame.ssavalue_uses[pc])
14601479
record_ssa_assign(pc, t, frame)
14611480
else

base/compiler/types.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,16 @@ may_compress(ni::NativeInterpreter) = true
212212
may_discard_trees(ni::NativeInterpreter) = true
213213

214214
method_table(ai::AbstractInterpreter) = InternalMethodTable(get_world_counter(ai))
215+
216+
# define inference bail out logic
217+
# `NativeInterpreter` bails out from inference when
218+
# - a lattice element grows up to `Any` (inter-procedural call, abstract apply)
219+
# - a lattice element gets down to `Bottom` (statement inference, local frame inference)
220+
# - inferring non-concrete toplevel call sites
221+
bail_out_call(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
222+
bail_out_apply(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
223+
bail_out_statement(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
224+
bail_out_local(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
225+
function bail_out_toplevel_call(interp::AbstractInterpreter, @nospecialize(sig), sv)
226+
return isa(sv.linfo.def, Module) && !isdispatchtuple(sig)
227+
end

0 commit comments

Comments
 (0)