Skip to content

Commit 0a2b2dc

Browse files
committed
inference: parameterize some of hard-coded inference logic
This commit parameterizes some of hard-coded inference logic: - to bail out from inference when a lattice element can't be refined or a current inference frame is proven to throw or to be dead - to add call backedges when the call return type won't be refined Those `AbstractInterpreter`s used for code optimization (including `NativeInterpreter`) usually just want the methods defined for `AbstractInterpreter`, but some other `AbstractInterpreter` may want other implementations and heuristics to control inference process. For example, [`JETInterpreter`](https://github.com/aviatesk/JET.jl) is used for code analysis and wants to add call backedges even when a call return type is `Any`.
1 parent 5cd1e3e commit 0a2b2dc

File tree

2 files changed

+51
-22
lines changed

2 files changed

+51
-22
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
9797
napplicable = length(applicable)
9898
rettype = Bottom
9999
edgecycle = false
100-
edges = Any[]
100+
edges = MethodInstance[]
101101
nonbot = 0 # the index of the only non-Bottom inference result if > 0
102102
seen = 0 # number of signatures actually inferred
103-
istoplevel = sv.linfo.def isa Module
104103
multiple_matches = napplicable > 1
105104

106105
if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
@@ -115,7 +114,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
115114
match = applicable[i]::MethodMatch
116115
method = match.method
117116
sig = match.spec_types
118-
if istoplevel && !isdispatchtuple(sig)
117+
if bail_out_toplevel_call(interp, sig, sv)
119118
# only infer concrete call sites in top-level expressions
120119
add_remark!(interp, sv, "Refusing to infer non-concrete call site in top-level expression")
121120
rettype = Any
@@ -135,7 +134,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
135134
end
136135
edgecycle |= edgecycle1::Bool
137136
this_rt = tmerge(this_rt, rt)
138-
this_rt === Any && break
137+
if bail_out_call(interp, this_rt, sv)
138+
break
139+
end
139140
end
140141
else
141142
this_rt, edgecycle1, edge = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv)
@@ -153,7 +154,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
153154
end
154155
seen += 1
155156
rettype = tmerge(rettype, this_rt)
156-
rettype === Any && break
157+
if bail_out_call(interp, rettype, sv)
158+
break
159+
end
157160
end
158161
# try constant propagation if only 1 method is inferred to non-Bottom
159162
# this is in preparation for inlining, or improving the return result
@@ -179,18 +182,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
179182
# and avoid keeping track of a more complex result type.
180183
rettype = Any
181184
end
182-
if !(rettype === Any) # adding a new method couldn't refine (widen) this type
183-
for edge in edges
184-
add_backedge!(edge::MethodInstance, sv)
185-
end
186-
for (thisfullmatch, mt) in zip(fullmatch, mts)
187-
if !thisfullmatch
188-
# also need an edge to the method table in case something gets
189-
# added that did not intersect with any existing method
190-
add_mt_backedge!(mt, atype, sv)
191-
end
192-
end
193-
end
185+
add_call_backedge!(interp, rettype, edges, fullmatch, mts, atype, sv)
194186
#print("=> ", rettype, "\n")
195187
if rettype isa LimitedAccuracy
196188
union!(sv.pclimitations, rettype.causes)
@@ -205,6 +197,26 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
205197
return CallMeta(rettype, info)
206198
end
207199

200+
function add_call_backedge!(interp::AbstractInterpreter,
201+
@nospecialize(rettype),
202+
edges::Vector{MethodInstance},
203+
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
204+
sv::InferenceState)
205+
if rettype === Any
206+
# for `NativeInterpreter`, adding a new method couldn't refine (widen) this type
207+
return
208+
end
209+
for edge in edges
210+
add_backedge!(edge, sv)
211+
end
212+
for (thisfullmatch, mt) in zip(fullmatch, mts)
213+
if !thisfullmatch
214+
# also need an edge to the method table in case something gets
215+
# added that did not intersect with any existing method
216+
add_mt_backedge!(mt, atype, sv)
217+
end
218+
end
219+
end
208220

209221
function const_prop_profitable(@nospecialize(arg))
210222
# have new information from argtypes that wasn't available from the signature
@@ -733,7 +745,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
733745
call = abstract_call(interp, nothing, ct, sv, max_methods)
734746
push!(retinfos, ApplyCallInfo(call.info, arginfo))
735747
res = tmerge(res, call.rt)
736-
if res === Any
748+
if bail_out_apply(interp, res, sv)
737749
# No point carrying forward the info, we're not gonna inline it anyway
738750
retinfo = nothing
739751
break
@@ -1158,7 +1170,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
11581170
argtypes = Vector{Any}(undef, n)
11591171
@inbounds for i = 1:n
11601172
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
1161-
if ai === Bottom
1173+
if bail_out_statement(interp, ai, sv)
11621174
return Bottom
11631175
end
11641176
argtypes[i] = ai
@@ -1334,7 +1346,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13341346
pc´ = (stmt::GotoNode).label
13351347
elseif isa(stmt, GotoIfNot)
13361348
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
1337-
if condt === Bottom
1349+
if bail_out_local(interp, condt, frame)
13381350
empty!(frame.pclimitations)
13391351
break
13401352
end
@@ -1427,7 +1439,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14271439
else
14281440
if hd === :(=)
14291441
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
1430-
t === Bottom && break
1442+
if bail_out_local(interp, t, frame)
1443+
break
1444+
end
14311445
frame.src.ssavaluetypes[pc] = t
14321446
lhs = stmt.args[1]
14331447
if isa(lhs, Slot)
@@ -1442,7 +1456,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14421456
# these do not generate code
14431457
else
14441458
t = abstract_eval_statement(interp, stmt, changes, frame)
1445-
t === Bottom && break
1459+
if bail_out_local(interp, t, frame)
1460+
break
1461+
end
14461462
if !isempty(frame.ssavalue_uses[pc])
14471463
record_ssa_assign(pc, t, frame)
14481464
else

base/compiler/types.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,16 @@ may_compress(ni::NativeInterpreter) = true
212212
may_discard_trees(ni::NativeInterpreter) = true
213213

214214
method_table(ai::AbstractInterpreter) = InternalMethodTable(get_world_counter(ai))
215+
216+
# define inference bail out logic
217+
# `NativeInterpreter` bails out from inference when
218+
# - a lattice element grows up to `Any` (inter-procedural call, abstract apply)
219+
# - a lattice element gets down to `Bottom` (statement inference, local frame inference)
220+
# - inferring non-concrete toplevel call sites
221+
bail_out_call(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
222+
bail_out_apply(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Any
223+
bail_out_statement(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
224+
bail_out_local(interp::AbstractInterpreter, @nospecialize(t), sv) = t === Bottom
225+
function bail_out_toplevel_call(interp::AbstractInterpreter, @nospecialize(sig), sv)
226+
return isa(sv.linfo.def, Module) && !isdispatchtuple(sig)
227+
end

0 commit comments

Comments
 (0)