Skip to content

Commit 9557259

Browse files
authored
inference: forward Conditional inter-procedurally (#42529)
The PR #38905 only "back-propagates" conditional constraint (from callee to caller), but currently we don't "forward" it (caller to callee), and so inter-procedural constraint propagation won't happen for e.g.: ```julia ifelselike(cnd, x, y) = cnd ? x : y @test Base.return_types((Any,Int,)) do x, y ifelselike(isa(x, Int), x, y) end |> only == Int ``` This commit complements #38905 and enables further inter-procedural conditional constraint propagation by forwarding `Conditional` to callees when it imposes a constraint on any other argument, during constant propagation. We also improve constant-prop' heuristics in these ways: - remove `const_prop_rettype_heuristic` since it handles rare cases, where const-prop' doens't seem to be worthwhile, e.g. it won't be so useful to try to propagate `Const(Tuple{DataType,DataType})` for `Const(convert)(::Const(Tuple{DataType,DataType}), ::Tuple{DataType,DataType} -> Tuple{DataType,DataType}` - rename `is_allconst` to `is_all_overridden` - also minor refactors and improvements added
1 parent 76c2431 commit 9557259

File tree

6 files changed

+201
-79
lines changed

6 files changed

+201
-79
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 86 additions & 61 deletions
Large diffs are not rendered by default.

base/compiler/inferenceresult.jl

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,57 @@
33
function is_argtype_match(@nospecialize(given_argtype),
44
@nospecialize(cache_argtype),
55
overridden_by_const::Bool)
6-
if isa(given_argtype, Const) || isa(given_argtype, PartialStruct) || isa(given_argtype, PartialOpaque)
6+
if is_forwardable_argtype(given_argtype)
77
return is_lattice_equal(given_argtype, cache_argtype)
88
end
99
return !overridden_by_const
1010
end
1111

12+
function is_forwardable_argtype(@nospecialize x)
13+
return isa(x, Const) ||
14+
isa(x, Conditional) ||
15+
isa(x, PartialStruct) ||
16+
isa(x, PartialOpaque)
17+
end
18+
1219
# In theory, there could be a `cache` containing a matching `InferenceResult`
1320
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
1421
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
1522
# so that we can construct cache-correct `InferenceResult`s in the first place.
16-
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool)
23+
function matching_cache_argtypes(
24+
linfo::MethodInstance, (; fargs, argtypes)::ArgInfo, va_override::Bool)
1725
@assert isa(linfo.def, Method) # ensure the next line works
1826
nargs::Int = linfo.def.nargs
19-
given_argtypes = anymap(widenconditional, given_argtypes)
27+
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
28+
given_argtypes = Vector{Any}(undef, length(argtypes))
29+
local condargs = nothing
30+
for i in 1:length(argtypes)
31+
argtype = argtypes[i]
32+
# forward `Conditional` if it conveys a constraint on any other argument
33+
if isa(argtype, Conditional) && fargs !== nothing
34+
cnd = argtype
35+
slotid = find_constrained_arg(cnd, fargs)
36+
if slotid !== nothing
37+
# using union-split signature, we may be able to narrow down `Conditional`
38+
sigt = widenconst(slotid > nargs ? argtypes[slotid] : cache_argtypes[slotid])
39+
vtype = tmeet(cnd.vtype, sigt)
40+
elsetype = tmeet(cnd.elsetype, sigt)
41+
if vtype === Bottom && elsetype === Bottom
42+
# we accidentally proved this method match is impossible
43+
# TODO bail out here immediately rather than just propagating Bottom ?
44+
given_argtypes[i] = Bottom
45+
else
46+
if condargs === nothing
47+
condargs = Tuple{Int,Int}[]
48+
end
49+
push!(condargs, (slotid, i))
50+
given_argtypes[i] = Conditional(SlotNumber(slotid), vtype, elsetype)
51+
end
52+
continue
53+
end
54+
end
55+
given_argtypes[i] = widenconditional(argtype)
56+
end
2057
isva = va_override || linfo.def.isva
2158
if isva || isvarargtype(given_argtypes[end])
2259
isva_given_argtypes = Vector{Any}(undef, nargs)
@@ -30,15 +67,22 @@ function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector,
3067
last = nargs
3168
end
3269
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
70+
# invalidate `Conditional` imposed on varargs
71+
if condargs !== nothing
72+
for (slotid, i) in condargs
73+
if slotid last
74+
isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i])
75+
end
76+
end
77+
end
3378
end
3479
given_argtypes = isva_given_argtypes
3580
end
3681
@assert length(given_argtypes) == nargs
37-
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
3882
for i in 1:nargs
3983
given_argtype = given_argtypes[i]
4084
cache_argtype = cache_argtypes[i]
41-
if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i])
85+
if !is_argtype_match(given_argtype, cache_argtype, false)
4286
# prefer the argtype we were given over the one computed from `linfo`
4387
cache_argtypes[i] = given_argtype
4488
overridden_by_const[i] = true

base/compiler/tfuncs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any
964964
v = unwrapva(argtypes[5])
965965
TF = getfield_tfunc(o, f)
966966
push!(sv.ssavalue_uses[sv.currpc], sv.currpc) # temporarily disable `call_result_unused` check for this call
967-
callinfo = abstract_call(interp, nothing, Any[op, TF, v], sv, #=max_methods=# 1)
967+
callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), sv, #=max_methods=# 1)
968968
pop!(sv.ssavalue_uses[sv.currpc], sv.currpc)
969969
TF2 = tmeet(callinfo.rt, widenconst(TF))
970970
if TF2 === Bottom
@@ -1747,7 +1747,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
17471747
if contains_is(argtypes_vec, Union{})
17481748
return CallMeta(Const(Union{}), false)
17491749
end
1750-
call = abstract_call(interp, nothing, argtypes_vec, sv, -1)
1750+
call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), sv, -1)
17511751
info = verbose_stmt_info(interp) ? ReturnTypeCallInfo(call.info) : false
17521752
rt = widenconditional(call.rt)
17531753
if isa(rt, Const)

base/compiler/types.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ If `interp` is an `AbstractInterpreter`, it is expected to provide at least the
1717
"""
1818
abstract type AbstractInterpreter end
1919

20+
struct ArgInfo
21+
fargs::Union{Nothing,Vector{Any}}
22+
argtypes::Vector{Any}
23+
end
24+
2025
"""
2126
InferenceResult
2227
@@ -29,8 +34,10 @@ mutable struct InferenceResult
2934
result # ::Type, or InferenceState if WIP
3035
src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available
3136
valid_worlds::WorldRange # if inference and optimization is finished
32-
function InferenceResult(linfo::MethodInstance, given_argtypes = nothing, va_override=false)
33-
argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes, va_override)
37+
function InferenceResult(linfo::MethodInstance,
38+
arginfo::Union{Nothing,ArgInfo} = nothing,
39+
va_override::Bool = false)
40+
argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo, va_override)
3441
return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange())
3542
end
3643
end

base/compiler/typeutils.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,15 +259,6 @@ unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body)::Int, unioncomplexity
259259
unioncomplexity(t::TypeofVararg) = isdefined(t, :T) ? unioncomplexity(t.T)::Int : 0
260260
unioncomplexity(@nospecialize(x)) = 0
261261

262-
function improvable_via_constant_propagation(@nospecialize(t))
263-
if isconcretetype(t) && t <: Tuple
264-
for p in t.parameters
265-
p === DataType && return true
266-
end
267-
end
268-
return false
269-
end
270-
271262
# convert a Union of Tuple types to a Tuple of Unions
272263
function unswitchtupleunion(u::Union)
273264
ts = uniontypes(u)

test/compiler/inference.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,6 +2000,61 @@ function _g_ifelse_isa_()
20002000
end
20012001
@test Base.return_types(_g_ifelse_isa_, ()) == [Int]
20022002

2003+
@testset "Conditional forwarding" begin
2004+
# forward `Conditional` if it conveys a constraint on any other argument
2005+
ifelselike(cnd, x, y) = cnd ? x : y
2006+
2007+
@test Base.return_types((Any,Int,)) do x, y
2008+
ifelselike(isa(x, Int), x, y)
2009+
end |> only == Int
2010+
2011+
# should work nicely with union-split
2012+
@test Base.return_types((Union{Int,Nothing},)) do x
2013+
ifelselike(isa(x, Int), x, 0)
2014+
end |> only == Int
2015+
2016+
@test Base.return_types((Any,Int)) do x, y
2017+
ifelselike(!isa(x, Int), y, x)
2018+
end |> only == Int
2019+
2020+
@test Base.return_types((Any,Int)) do x, y
2021+
a = ifelselike(x === 0, x, 0) # ::Const(0)
2022+
if a == 0
2023+
return y
2024+
else
2025+
return nothing # dead branch
2026+
end
2027+
end |> only == Int
2028+
2029+
# pick up the first if there are multiple constrained arguments
2030+
@test Base.return_types((Any,)) do x
2031+
ifelselike(isa(x, Int), x, x)
2032+
end |> only == Any
2033+
2034+
# just propagate multiple constraints
2035+
ifelselike2(cnd1, cnd2, x, y, z) = cnd1 ? x : cnd2 ? y : z
2036+
@test Base.return_types((Any,Any)) do x, y
2037+
ifelselike2(isa(x, Int), isa(y, Int), x, y, 0)
2038+
end |> only == Int
2039+
2040+
# work with `invoke`
2041+
@test Base.return_types((Any,Any)) do x, y
2042+
Base.@invoke ifelselike(isa(x, Int), x, y::Int)
2043+
end |> only == Int
2044+
2045+
# don't be confused with vararg method
2046+
vacond(cnd, va...) = cnd ? va : 0
2047+
@test Base.return_types((Any,)) do x
2048+
# at runtime we will see `va::Tuple{Tuple{Int,Int}, Tuple{Int,Int}}`
2049+
vacond(isa(x, Tuple{Int,Int}), x, x)
2050+
end |> only == Union{Int,Tuple{Any,Any}}
2051+
2052+
# demonstrate extra constraint propagation for Base.ifelse
2053+
@test Base.return_types((Any,Int,)) do x, y
2054+
ifelse(isa(x, Int), x, y)
2055+
end |> only == Int
2056+
end
2057+
20032058
# Equivalence of Const(T.instance) and T for singleton types
20042059
@test Const(nothing) Nothing && Nothing Const(nothing)
20052060

0 commit comments

Comments
 (0)