Skip to content

Commit c33a47d

Browse files
committed
inference: forward Conditional inter-procedurally
The PR #38905 only "back-propagates" conditional constraint (from callee to caller), but currently we don't "forward" it (caller to callee), and so inter-procedural constraint propagation won't happen for e.g.: ```julia ifelselike(cnd, x, y) = cnd ? x : y @test Base.return_types((Any,Int,)) do x, y ifelselike(isa(x, Int), x, y) end |> only == Int ``` This commit complements #38905 and enables further inter-procedural conditional constraint propagation by forwarding `Conditional` to callees when it imposes a constraint on any other argument, during constant propagation.
1 parent cd19e97 commit c33a47d

File tree

5 files changed

+163
-50
lines changed

5 files changed

+163
-50
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 73 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function is_improvable(@nospecialize(rtype))
2929
end
3030

3131
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
32-
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
32+
(; fargs, argtypes)::ArgInfo, @nospecialize(atype),
3333
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
3434
if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv))
3535
add_remark!(interp, sv, "Skipped call in throw block")
@@ -85,7 +85,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
8585
push!(edges, edge)
8686
end
8787
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
88-
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
88+
arginfo = ArgInfo(fargs, this_argtypes)
89+
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
8990
if const_result !== nothing
9091
const_rt, const_result = const_result
9192
if const_rt !== rt && const_rt rt
@@ -110,7 +111,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
110111
# try constant propagation with argtypes for this match
111112
# this is in preparation for inlining, or improving the return result
112113
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
113-
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
114+
arginfo = ArgInfo(fargs, this_argtypes)
115+
const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false)
114116
if const_result !== nothing
115117
const_this_rt, const_result = const_result
116118
if const_this_rt !== this_rt && const_this_rt this_rt
@@ -523,13 +525,13 @@ struct MethodCallResult
523525
end
524526

525527
function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult,
526-
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
528+
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
527529
sv::InferenceState, va_override::Bool)
528-
mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv)
530+
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
529531
mi === nothing && return nothing
530532
# try constant prop'
531533
inf_cache = get_inference_cache(interp)
532-
inf_result = cache_lookup(mi, argtypes, inf_cache)
534+
inf_result = cache_lookup(mi, arginfo.argtypes, inf_cache)
533535
if inf_result === nothing
534536
# if there might be a cycle, check to make sure we don't end up
535537
# calling ourselves here.
@@ -545,7 +547,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
545547
return nothing
546548
end
547549
end
548-
inf_result = InferenceResult(mi, argtypes, va_override)
550+
inf_result = InferenceResult(mi; arginfo, va_override)
549551
if !any(inf_result.overridden_by_const)
550552
add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes")
551553
return nothing
@@ -565,7 +567,7 @@ end
565567
# if there's a possibility we could get a better result (hopefully without doing too much work)
566568
# returns `MethodInstance` with constant arguments, returns nothing otherwise
567569
function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult,
568-
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
570+
@nospecialize(f), arginfo::ArgInfo, match::MethodMatch,
569571
sv::InferenceState)
570572
if !InferenceParams(interp).ipo_constant_propagation
571573
add_remark!(interp, sv, "[constprop] Disabled by parameter")
@@ -580,14 +582,14 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
580582
force || const_prop_entry_heuristic(interp, result, sv) || return nothing
581583
nargs::Int = method.nargs
582584
method.isva && (nargs -= 1)
583-
length(argtypes) < nargs && return nothing
584-
if !(const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, result.rt))
585+
length(arginfo.argtypes) < nargs && return nothing
586+
if !(const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, result.rt))
585587
add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics")
586588
return nothing
587589
end
588-
allconst = is_allconst(argtypes)
590+
allconst = is_allconst(arginfo)
589591
if !force
590-
if !const_prop_function_heuristic(interp, f, argtypes, nargs, allconst)
592+
if !const_prop_function_heuristic(interp, f, arginfo, nargs, allconst)
591593
add_remark!(interp, sv, "[constprop] Disabled by function heuristic")
592594
return nothing
593595
end
@@ -599,7 +601,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
599601
return nothing
600602
end
601603
mi = mi::MethodInstance
602-
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv)
604+
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
603605
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
604606
return nothing
605607
end
@@ -617,8 +619,11 @@ function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodC
617619
end
618620

619621
# see if propagating constants may be worthwhile
620-
function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any})
622+
function const_prop_argument_heuristic(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo)
621623
for a in argtypes
624+
if isa(a, Conditional) && fargs !== nothing
625+
return is_const_prop_profitable_conditional(a, fargs)
626+
end
622627
a = widenconditional(a)
623628
if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a)
624629
return true
@@ -642,13 +647,34 @@ function is_const_prop_profitable_arg(@nospecialize(arg))
642647
return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val))
643648
end
644649

650+
function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any})
651+
slotid = find_constrained_arg(cnd, fargs)
652+
if slotid !== nothing
653+
return true
654+
end
655+
return is_const_prop_profitable_arg(widenconditional(cnd))
656+
end
657+
658+
function find_constrained_arg(cnd::Conditional, fargs::Vector{Any})
659+
slot = cnd.var
660+
return findfirst(fargs) do @nospecialize(x)
661+
x === slot
662+
end
663+
end
664+
645665
function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype))
646666
return improvable_via_constant_propagation(rettype)
647667
end
648668

649-
function is_allconst(argtypes::Vector{Any})
669+
function is_allconst((; fargs, argtypes)::ArgInfo)
650670
for a in argtypes
671+
if isa(a, Conditional) && fargs !== nothing
672+
if is_const_prop_profitable_conditional(a, fargs)
673+
continue
674+
end
675+
end
651676
a = widenconditional(a)
677+
# TODO unify these condition with `has_nontrivial_const_info`
652678
if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque)
653679
return false
654680
end
@@ -663,7 +689,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method:
663689
istopfunction(f, :setproperty!)
664690
end
665691

666-
function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool)
692+
function const_prop_function_heuristic(
693+
interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo,
694+
nargs::Int, allconst::Bool)
667695
if nargs > 1
668696
if istopfunction(f, :getindex) || istopfunction(f, :setindex!)
669697
arrty = argtypes[2]
@@ -705,7 +733,7 @@ end
705733
# result anyway.
706734
function const_prop_methodinstance_heuristic(
707735
interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance,
708-
argtypes::Vector{Any}, sv::InferenceState)
736+
(; argtypes)::ArgInfo, sv::InferenceState)
709737
method = match.method
710738
if method.is_for_opaque_closure
711739
# Not inlining an opaque closure can be very expensive, so be generous
@@ -832,7 +860,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
832860
return Any[Vararg{Any}], nothing
833861
end
834862
@assert !isvarargtype(itertype)
835-
call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], sv)
863+
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), sv)
836864
stateordonet = call.rt
837865
info = call.info
838866
# Return Bottom if this is not an iterator.
@@ -866,7 +894,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
866894
valtype = getfield_tfunc(stateordonet, Const(1))
867895
push!(ret, valtype)
868896
statetype = nstatetype
869-
call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv)
897+
call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv)
870898
stateordonet = call.rt
871899
stateordonet_widened = widenconst(stateordonet)
872900
push!(calls, call)
@@ -901,7 +929,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
901929
end
902930
valtype = tmerge(valtype, nounion.parameters[1])
903931
statetype = tmerge(statetype, nounion.parameters[2])
904-
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
932+
stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv).rt
905933
stateordonet_widened = widenconst(stateordonet)
906934
end
907935
if valtype !== Union{}
@@ -990,7 +1018,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::
9901018
break
9911019
end
9921020
end
993-
call = abstract_call(interp, nothing, ct, sv, max_methods)
1021+
call = abstract_call(interp, ArgInfo(nothing, ct), sv, max_methods)
9941022
push!(retinfos, ApplyCallInfo(call.info, arginfo))
9951023
res = tmerge(res, call.rt)
9961024
if bail_out_apply(interp, res, sv)
@@ -1054,8 +1082,8 @@ function argtype_tail(argtypes::Vector{Any}, i::Int)
10541082
return argtypes[i:n]
10551083
end
10561084

1057-
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}},
1058-
argtypes::Vector{Any}, sv::InferenceState, max_methods::Int)
1085+
function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs, argtypes)::ArgInfo,
1086+
sv::InferenceState, max_methods::Int)
10591087
@nospecialize f
10601088
la = length(argtypes)
10611089
if f === ifelse && fargs isa Vector{Any} && la == 4
@@ -1188,7 +1216,7 @@ function abstract_call_unionall(argtypes::Vector{Any})
11881216
return Any
11891217
end
11901218

1191-
function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
1219+
function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState)
11921220
ft′ = argtype_by_index(argtypes, 2)
11931221
ft = widenconst(ft′)
11941222
ft === Bottom && return CallMeta(Bottom, false)
@@ -1216,14 +1244,17 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
12161244
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
12171245
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
12181246
argtypes′ = argtypes[4:end]
1219-
const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
12201247
pushfirst!(argtypes′, ft)
1248+
fargs′ = fargs[4:end]
1249+
pushfirst!(fargs′, fargs[1])
1250+
arginfo = ArgInfo(fargs′, argtypes′)
1251+
const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
12211252
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
12221253
# for i in 1:length(argtypes′)
12231254
# t, a = ti.parameters[i], argtypes′[i]
12241255
# argtypes′[i] = t ⊑ a ? t : a
12251256
# end
1226-
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), argtypes′, match, sv, false)
1257+
const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv, false)
12271258
if const_result !== nothing
12281259
const_rt, const_result = const_result
12291260
if const_rt !== rt && const_rt rt
@@ -1235,21 +1266,20 @@ end
12351266

12361267
# call where the function is known exactly
12371268
function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
1238-
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
1239-
sv::InferenceState,
1269+
arginfo::ArgInfo, sv::InferenceState,
12401270
max_methods::Int = InferenceParams(interp).MAX_METHODS)
1241-
1271+
(; fargs, argtypes) = arginfo
12421272
la = length(argtypes)
12431273

12441274
if isa(f, Builtin)
12451275
if f === _apply_iterate
12461276
return abstract_apply(interp, argtypes, sv, max_methods)
12471277
elseif f === invoke
1248-
return abstract_invoke(interp, argtypes, sv)
1278+
return abstract_invoke(interp, arginfo, sv)
12491279
elseif f === modifyfield!
12501280
return abstract_modifyfield!(interp, argtypes, sv)
12511281
end
1252-
return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false)
1282+
return CallMeta(abstract_call_builtin(interp, f, arginfo, sv, max_methods), false)
12531283
elseif f === Core.kwfunc
12541284
if la == 2
12551285
ft = widenconst(argtypes[2])
@@ -1282,12 +1312,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
12821312
# handle Conditional propagation through !Bool
12831313
aty = argtypes[2]
12841314
if isa(aty, Conditional)
1285-
call = abstract_call_gf_by_type(interp, f, fargs, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
1315+
call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)`
12861316
return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info)
12871317
end
12881318
elseif la == 3 && istopfunction(f, :!==)
12891319
# mark !== as exactly a negated call to ===
1290-
rty = abstract_call_known(interp, (===), fargs, argtypes, sv).rt
1320+
rty = abstract_call_known(interp, (===), arginfo, sv).rt
12911321
if isa(rty, Conditional)
12921322
return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else
12931323
elseif isa(rty, Const)
@@ -1303,7 +1333,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
13031333
fargs = nothing
13041334
end
13051335
argtypes = Any[typeof(<:), argtypes[3], argtypes[2]]
1306-
return CallMeta(abstract_call_known(interp, <:, fargs, argtypes, sv).rt, false)
1336+
return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv).rt, false)
13071337
elseif la == 2 &&
13081338
(a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) &&
13091339
istopfunction(f, :length)
@@ -1326,7 +1356,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
13261356
return CallMeta(val === false ? Type : val, MethodResultPure())
13271357
end
13281358
atype = argtypes_to_type(argtypes)
1329-
return abstract_call_gf_by_type(interp, f, fargs, argtypes, atype, sv, max_methods)
1359+
return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods)
13301360
end
13311361

13321362
function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
@@ -1339,8 +1369,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part
13391369
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
13401370
info = OpaqueClosureCallInfo(match)
13411371
if !result.edgecycle
1342-
const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
1343-
match, sv, closure.isva)
1372+
const_result = abstract_call_method_with_const_args(interp, result, closure,
1373+
ArgInfo(nothing, argtypes), match, sv, closure.isva)
13441374
if const_result !== nothing
13451375
const_rettype, const_result = const_result
13461376
if const_rettype rt
@@ -1363,9 +1393,9 @@ function most_general_argtypes(closure::PartialOpaque)
13631393
end
13641394

13651395
# call where the function is any lattice element
1366-
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
1396+
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo,
13671397
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
1368-
#print("call ", e.args[1], argtypes, "\n\n")
1398+
argtypes = arginfo.argtypes
13691399
ft = argtypes[1]
13701400
f = singleton_type(ft)
13711401
if isa(ft, PartialOpaque)
@@ -1379,9 +1409,9 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
13791409
add_remark!(interp, sv, "Could not identify method table for call")
13801410
return CallMeta(Any, false)
13811411
end
1382-
return abstract_call_gf_by_type(interp, nothing, fargs, argtypes, argtypes_to_type(argtypes), sv, max_methods)
1412+
return abstract_call_gf_by_type(interp, nothing, arginfo, argtypes_to_type(argtypes), sv, max_methods)
13831413
end
1384-
return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
1414+
return abstract_call_known(interp, f, arginfo, sv, max_methods)
13851415
end
13861416

13871417
function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
@@ -1428,7 +1458,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V
14281458
# this may be the wrong world for the call,
14291459
# but some of the result is likely to be valid anyways
14301460
# and that may help generate better codegen
1431-
abstract_call(interp, nothing, at, sv)
1461+
abstract_call(interp, ArgInfo(nothing, at), sv)
14321462
nothing
14331463
end
14341464

@@ -1502,7 +1532,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15021532
if argtypes === nothing
15031533
t = Bottom
15041534
else
1505-
callinfo = abstract_call(interp, ea, argtypes, sv)
1535+
callinfo = abstract_call(interp, ArgInfo(ea, argtypes), sv)
15061536
sv.stmt_info[sv.currpc] = callinfo.info
15071537
t = callinfo.rt
15081538
end

0 commit comments

Comments
 (0)