Skip to content

Commit edd2866

Browse files
committed
inference: forward Conditional inter-procedurally
The PR #38905 only "back-propagates" conditional constraint (from callee to caller), but currently we don't "forward" it (caller to callee), and so inter-procedural constraint propagation won't happen for e.g.: ```julia ifelselike(cnd, x, y) = cnd ? x : y @test Base.return_types((Any,Int,)) do x, y ifelselike(isa(x, Int), x, y) end |> only == Int ``` This commit complements #38905 and enables further inter-procedural conditional constraint propagation by forwarding `Conditional` to callees when it imposes a constraint on any other argument, during constant propagation.
1 parent c8cc1b5 commit edd2866

File tree

5 files changed

+163
-50
lines changed

5 files changed

+163
-50
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 73 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function is_improvable(@nospecialize(rtype))
2929
end
3030

3131
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
32-
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
32+
(; fargs, argtypes)::ArgInfo, @nospecialize(atype),
3333
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
3434
if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv))
3535
add_remark!(interp, sv, "Skipped call in throw block")
@@ -85,7 +85,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
8585
push!(edges, edge)
8686
end
8787
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
88-
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
88+
arginfo = ArgInfo(fargs, this_argtypes)
89+
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
8990
if const_result !== nothing
9091
const_rt, const_result = const_result
9192
if const_rt !== rt && const_rt rt
@@ -110,7 +111,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
110111
# try constant propagation with argtypes for this match
111112
# this is in preparation for inlining, or improving the return result
112113
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
113-
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
114+
arginfo = ArgInfo(fargs, this_argtypes)
115+
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
114116
if const_result !== nothing
115117
const_this_rt, const_result = const_result
116118
if const_this_rt !== this_rt && const_this_rt this_rt
@@ -523,13 +525,13 @@ struct MethodCallResult
523525
end
524526

525527
function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
526-
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
528+
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
527529
sv::InferenceState, va_override::Bool)
528-
mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv)
530+
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
529531
mi === nothing && return nothing
530532
# try constant prop'
531533
inf_cache = get_inference_cache(interp)
532-
inf_result = cache_lookup(mi, argtypes, inf_cache)
534+
inf_result = cache_lookup(mi, arginfo.argtypes, inf_cache)
533535
if inf_result === nothing
534536
# if there might be a cycle, check to make sure we don't end up
535537
# calling ourselves here.
@@ -545,7 +547,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
545547
return nothing
546548
end
547549
end
548-
inf_result = InferenceResult(mi, argtypes, va_override)
550+
inf_result = InferenceResult(mi; arginfo, va_override)
549551
if !any(inf_result.overridden_by_const)
550552
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
551553
return nothing
@@ -565,7 +567,7 @@ end
565567
# if there's a possibility we could get a better result (hopefully without doing too much work)
566568
# returns `MethodInstance` with constant arguments, returns nothing otherwise
567569
function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult,
568-
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
570+
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
569571
sv::InferenceState)
570572
if !InferenceParams(interp).ipo_constant_propagation
571573
add_remark!(interp, sv, "[constprop] Disabled by parameter")
@@ -580,14 +582,14 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
580582
force || const_prop_entry_heuristic(interp, result, sv) || return nothing
581583
nargs::Int = method.nargs
582584
method.isva && (nargs -= 1)
583-
length(argtypes) < nargs && return nothing
584-
if !(const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, result.rt))
585+
length(arginfo.argtypes) < nargs && return nothing
586+
if !(const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, result.rt))
585587
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
586588
return nothing
587589
end
588-
allconst = is_allconst(argtypes)
590+
allconst = is_allconst(arginfo)
589591
if !force
590-
if !const_prop_function_heuristic(interp, f, argtypes, nargs, allconst)
592+
if !const_prop_function_heuristic(interp, f, arginfo, nargs, allconst)
591593
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
592594
return nothing
593595
end
@@ -599,7 +601,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
599601
return nothing
600602
end
601603
mi = mi::MethodInstance
602-
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv)
604+
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
603605
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
604606
return nothing
605607
end
@@ -617,8 +619,11 @@ function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodC
617619
end
618620

619621
# see if propagating constants may be worthwhile
620-
function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any})
622+
function const_prop_argument_heuristic(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo)
621623
for a in argtypes
624+
if isa(a, Conditional) && fargs !== nothing
625+
return is_const_prop_profitable_conditional(a, fargs)
626+
end
622627
a = widenconditional(a)
623628
if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a)
624629
return true
@@ -642,13 +647,34 @@ function is_const_prop_profitable_arg(@nospecialize(arg))
642647
return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val))
643648
end
644649

650+
function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any})
651+
slotid = find_constrained_arg(cnd, fargs)
652+
if slotid !== nothing
653+
return true
654+
end
655+
return is_const_prop_profitable_arg(widenconditional(cnd))
656+
end
657+
658+
function find_constrained_arg(cnd::Conditional, fargs::Vector{Any})
659+
slot = cnd.var
660+
return findfirst(fargs) do @nospecialize(x)
661+
x === slot
662+
end
663+
end
664+
645665
function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype))
646666
return improvable_via_constant_propagation(rettype)
647667
end
648668

649-
function is_allconst(argtypes::Vector{Any})
669+
function is_allconst((; fargs, argtypes)::ArgInfo)
650670
for a in argtypes
671+
if isa(a, Conditional) && fargs !== nothing
672+
if is_const_prop_profitable_conditional(a, fargs)
673+
continue
674+
end
675+
end
651676
a = widenconditional(a)
677+
# TODO unify these condition with `has_nontrivial_const_info`
652678
if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque)
653679
return false
654680
end
@@ -663,7 +689,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
663689
istopfunction(f, :setproperty!)
664690
end
665691

666-
function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool)
692+
function const_prop_function_heuristic(
693+
interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo,
694+
nargs::Int, allconst::Bool)
667695
if nargs > 1
668696
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
669697
arrty = argtypes[2]
@@ -705,7 +733,7 @@ end
705733
# result anyway.
706734
function const_prop_methodinstance_heuristic(
707735
interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance,
708-
argtypes::Vector{Any}, sv::InferenceState)
736+
(; argtypes)::ArgInfo, sv::InferenceState)
709737
method = match.method
710738
if method.is_for_opaque_closure
711739
# Not inlining an opaque closure can be very expensive, so be generous
@@ -835,7 +863,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
835863
return Any[Vararg{Any}], nothing
836864
end
837865
@assert !isvarargtype(itertype)
838-
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], sv)
866+
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), sv)
839867
stateordonet = call.rt
840868
info = call.info
841869
# Return Bottom if this is not an iterator.
@@ -869,7 +897,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
869897
valtype = getfield_tfunc(stateordonet, Const(1))
870898
push!(ret, valtype)
871899
statetype = nstatetype
872-
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv)
900+
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv)
873901
stateordonet = call.rt
874902
stateordonet_widened = widenconst(stateordonet)
875903
push!(calls, call)
@@ -904,7 +932,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
904932
end
905933
valtype = tmerge(valtype, nounion.parameters[1])
906934
statetype = tmerge(statetype, nounion.parameters[2])
907-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
935+
stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv).rt
908936
stateordonet_widened = widenconst(stateordonet)
909937
end
910938
if valtype !== Union{}
@@ -993,7 +1021,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
9931021
break
9941022
end
9951023
end
996-
call = abstract_call(interp, nothing, ct, sv, max_methods)
1024+
call = abstract_call(interp, ArgInfo(nothing, ct), sv, max_methods)
9971025
push!(retinfos, ApplyCallInfo(call.info, arginfo))
9981026
res = tmerge(res, call.rt)
9991027
if bail_out_apply(interp, res, sv)
@@ -1057,8 +1085,8 @@ function argtype_tail(argtypes::Vector{Any}, i::Int)
10571085
return argtypes[i:n]
10581086
end
10591087

1060-
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}},
1061-
argtypes::Vector{Any}, sv::InferenceState, max_methods::Int)
1088+
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs, argtypes)::ArgInfo,
1089+
sv::InferenceState, max_methods::Int)
10621090
@nospecialize f
10631091
la = length(argtypes)
10641092
if f === ifelse && fargs isa Vector{Any} && la == 4
@@ -1190,7 +1218,7 @@ function abstract_call_unionall(argtypes::Vector{Any})
11901218
return Any
11911219
end
11921220

1193-
function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
1221+
function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState)
11941222
ft′ = argtype_by_index(argtypes, 2)
11951223
ft = widenconst(ft′)
11961224
ft === Bottom && return CallMeta(Bottom, false)
@@ -1218,14 +1246,17 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
12181246
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
12191247
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
12201248
argtypes′ = argtypes[4:end]
1221-
const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
12221249
pushfirst!(argtypes′, ft)
1250+
fargs′ = fargs[4:end]
1251+
pushfirst!(fargs′, fargs[1])
1252+
arginfo = ArgInfo(fargs′, argtypes′)
1253+
const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
12231254
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
12241255
# for i in 1:length(argtypes′)
12251256
# t, a = ti.parameters[i], argtypes′[i]
12261257
# argtypes′[i] = t ⊑ a ? t : a
12271258
# end
1228-
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), argtypes′, match, sv, false)
1259+
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv, false)
12291260
if const_result !== nothing
12301261
const_rt, const_result = const_result
12311262
if const_rt !== rt && const_rt rt
@@ -1237,21 +1268,20 @@ end
12371268

12381269
# call where the function is known exactly
12391270
function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
1240-
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
1241-
sv::InferenceState,
1271+
arginfo::ArgInfo, sv::InferenceState,
12421272
max_methods::Int = InferenceParams(interp).MAX_METHODS)
1243-
1273+
(; fargs, argtypes) = arginfo
12441274
la = length(argtypes)
12451275

12461276
if isa(f, Builtin)
12471277
if f === _apply_iterate
12481278
return abstract_apply(interp, argtypes, sv, max_methods)
12491279
elseif f === invoke
1250-
return abstract_invoke(interp, argtypes, sv)
1280+
return abstract_invoke(interp, arginfo, sv)
12511281
elseif f === modifyfield!
12521282
return abstract_modifyfield!(interp, argtypes, sv)
12531283
end
1254-
return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false)
1284+
return CallMeta(abstract_call_builtin(interp, f, arginfo, sv, max_methods), false)
12551285
elseif f === Core.kwfunc
12561286
if la == 2
12571287
ft = widenconst(argtypes[2])
@@ -1284,12 +1314,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
12841314
# handle Conditional propagation through !Bool
12851315
aty = argtypes[2]
12861316
if isa(aty, Conditional)
1287-
call = abstract_call_gf_by_type(interp, f, fargs, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
1317+
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
12881318
return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info)
12891319
end
12901320
elseif la == 3 && istopfunction(f, :!==)
12911321
# mark !== as exactly a negated call to ===
1292-
rty = abstract_call_known(interp, (===), fargs, argtypes, sv).rt
1322+
rty = abstract_call_known(interp, (===), arginfo, sv).rt
12931323
if isa(rty, Conditional)
12941324
return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else
12951325
elseif isa(rty, Const)
@@ -1305,7 +1335,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
13051335
fargs = nothing
13061336
end
13071337
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
1308-
return CallMeta(abstract_call_known(interp, <:, fargs, argtypes, sv).rt, false)
1338+
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv).rt, false)
13091339
elseif la == 2 &&
13101340
(a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) &&
13111341
istopfunction(f, :length)
@@ -1328,7 +1358,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
13281358
return CallMeta(val === false ? Type : val, MethodResultPure())
13291359
end
13301360
atype = argtypes_to_type(argtypes)
1331-
return abstract_call_gf_by_type(interp, f, fargs, argtypes, atype, sv, max_methods)
1361+
return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods)
13321362
end
13331363

13341364
function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
@@ -1341,8 +1371,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
13411371
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
13421372
info = OpaqueClosureCallInfo(match)
13431373
if !result.edgecycle
1344-
const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
1345-
match, sv, closure.isva)
1374+
const_result = abstract_call_method_with_const_args(interp, result, closure,
1375+
ArgInfo(nothing, argtypes), match, sv, closure.isva)
13461376
if const_result !== nothing
13471377
const_rettype, const_result = const_result
13481378
if const_rettype rt
@@ -1365,9 +1395,9 @@ function most_general_argtypes(closure::PartialOpaque)
13651395
end
13661396

13671397
# call where the function is any lattice element
1368-
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
1398+
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
13691399
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
1370-
#print("call ", e.args[1], argtypes, "\n\n")
1400+
argtypes = arginfo.argtypes
13711401
ft = argtypes[1]
13721402
f = singleton_type(ft)
13731403
if isa(ft, PartialOpaque)
@@ -1381,9 +1411,9 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
13811411
add_remark!(interp, sv, "Could not identify method table for call")
13821412
return CallMeta(Any, false)
13831413
end
1384-
return abstract_call_gf_by_type(interp, nothing, fargs, argtypes, argtypes_to_type(argtypes), sv, max_methods)
1414+
return abstract_call_gf_by_type(interp, nothing, arginfo, argtypes_to_type(argtypes), sv, max_methods)
13851415
end
1386-
return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
1416+
return abstract_call_known(interp, f, arginfo, sv, max_methods)
13871417
end
13881418

13891419
function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
@@ -1433,7 +1463,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V
14331463
# this may be the wrong world for the call,
14341464
# but some of the result is likely to be valid anyways
14351465
# and that may help generate better codegen
1436-
abstract_call(interp, nothing, at, sv)
1466+
abstract_call(interp, ArgInfo(nothing, at), sv)
14371467
nothing
14381468
end
14391469

@@ -1507,7 +1537,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15071537
if argtypes === nothing
15081538
t = Bottom
15091539
else
1510-
callinfo = abstract_call(interp, ea, argtypes, sv)
1540+
callinfo = abstract_call(interp, ArgInfo(ea, argtypes), sv)
15111541
sv.stmt_info[sv.currpc] = callinfo.info
15121542
t = callinfo.rt
15131543
end

0 commit comments

Comments
 (0)