Skip to content

Commit f143f7b

Browse files
committed
inference: inter-procedural conditional constraint back-propagation
This PR propagates `Conditional`s inter-procedurally when a `Conditional` at return site imposes a constraint on the call arguments. When inference exits local frame and the return type is annotated as `Conditional`, it will be converted into `InterConditional` object, which is implemented in `Core` and can be directly put into the global cache. Finally after going back to caller frame, `InterConditional` will be re-converted into `Conditional` in the context of the caller frame. So now some simple "is-wrapper" functions will propagate its constraint as expected, e.g.: ```julia isaint(a) = isa(a, Int) @test Base.return_types((Any,)) do a isaint(a) && return a # a::Int return 0 end == [Int] ``` This PR also tweaks `isnothing` and `ismissing` so that there is no longer any inferrability penalties to use them instead of `x === nothing` or `x === missing` e.g.: ```julia @test Base.return_types((Union{Nothing,Int},)) do a isnothing(a) && return 0 return a # a::Int end == [Int] ``` (and now we don't need something like #38636) There're certain limitations around. One of the biggest ones would be that it can propagate constrains only on a single argument, and it fails to back-propagate constrains when there're multiple conditions on different slots, e.g. `Meta.isexpr` can't still propagate its type constraint to the caller: ```julia @test_broken Base.return_types((Any,)) do x Meta.isexpr(x, :call) && return x # still x::Any but ideally x::Expr return nothing end == [Nothing,Expr] ``` (and because of this reason, this PR can't close #37342)
1 parent f184f05 commit f143f7b

File tree

14 files changed

+144
-20
lines changed

14 files changed

+144
-20
lines changed

base/boot.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
424424
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
425425
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
426426
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
427+
eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype))))
427428
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
428429
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))
429430

base/compiler/abstractinterpretation.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,7 +1144,10 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
11441144
end
11451145
callinfo = abstract_call(interp, ea, argtypes, sv)
11461146
sv.stmt_info[sv.currpc] = callinfo.info
1147-
t = callinfo.rt
1147+
rt = callinfo.rt
1148+
t = isa(rt, InterConditional) ?
1149+
transform_from_interconditional(rt, ea) :
1150+
rt
11481151
elseif e.head === :new
11491152
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
11501153
if isconcretetype(t) && !t.mutable
@@ -1255,6 +1258,19 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
12551258
return t
12561259
end
12571260

1261+
# try to convert interprocedural-conditional constraint from callee into constraints for
1262+
# the current frame
1263+
function transform_from_interconditional(rt::InterConditional, ea::Vector{Any})
1264+
i = rt.slot
1265+
if checkbounds(Bool, ea, i)
1266+
e = @inbounds ea[i]
1267+
if isa(e, Slot)
1268+
return Conditional(e, rt.vtype, rt.elsetype)
1269+
end
1270+
end
1271+
return widenconditional(rt)
1272+
end
1273+
12581274
function abstract_eval_global(M::Module, s::Symbol)
12591275
if isdefined(M,s) && isconst(M,s)
12601276
return Const(getfield(M,s))
@@ -1338,8 +1354,11 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13381354
end
13391355
elseif isa(stmt, ReturnNode)
13401356
pc´ = n + 1
1341-
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
1342-
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
1357+
rt = abstract_eval_value(interp, stmt.val, s[pc], frame)
1358+
if !isa(rt, Const) &&
1359+
!isa(rt, Type) &&
1360+
!isa(rt, PartialStruct) &&
1361+
!isa(rt, Conditional)
13431362
# only propagate information we know we can store
13441363
# and is valid inter-procedurally
13451364
rt = widenconst(rt)

base/compiler/tfuncs.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,6 +1649,9 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
16491649
return Const(Union{})
16501650
end
16511651
rt = abstract_call(interp, nothing, argtypes_vec, sv, -1).rt
1652+
if isa(rt, InterConditional)
1653+
rt = widenconditional(rt)
1654+
end
16521655
if isa(rt, Const)
16531656
# output was computed to be constant
16541657
return Const(typeof(rt.val))

base/compiler/typeinfer.jl

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -286,28 +286,33 @@ end
286286
function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::Any),
287287
valid_worlds::WorldRange)
288288
local const_flags::Int32
289+
res = result.result
290+
rettype = widenconst(res)
289291
if inferred_result isa Const
290292
# use constant calling convention
291293
rettype_const = (result.src::Const).val
292294
const_flags = 0x3
293295
inferred_result = nothing
294296
else
295-
if isa(result.result, Const)
296-
rettype_const = (result.result::Const).val
297+
if isa(res, Const)
298+
rettype_const = res.val
297299
const_flags = 0x2
298-
elseif isconstType(result.result)
299-
rettype_const = result.result.parameters[1]
300+
elseif isconstType(res)
301+
rettype_const = res.parameters[1]
300302
const_flags = 0x2
301-
elseif isa(result.result, PartialStruct)
302-
rettype_const = (result.result::PartialStruct).fields
303+
elseif isa(res, PartialStruct)
304+
rettype_const = res.fields
303305
const_flags = 0x2
304306
else
307+
if isa(res, Conditional)
308+
rettype = transform_to_interconditional(res, length(result.argtypes))
309+
end
305310
rettype_const = nothing
306311
const_flags = 0x00
307312
end
308313
end
309314
return CodeInstance(result.linfo,
310-
widenconst(result.result), rettype_const, inferred_result,
315+
rettype, rettype_const, inferred_result,
311316
const_flags, first(valid_worlds), last(valid_worlds))
312317
end
313318

@@ -759,15 +764,32 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
759764
end
760765
typeinf(interp, frame)
761766
update_valid_age!(frame, caller)
762-
return frame.bestguess, frame.inferred ? mi : nothing
767+
bestguess = frame.bestguess
768+
if isa(bestguess, Conditional)
769+
bestguess = transform_to_interconditional(bestguess, length(result.argtypes))
770+
end
771+
return bestguess, frame.inferred ? mi : nothing
763772
elseif frame === true
764773
# unresolvable cycle
765774
return Any, nothing
766775
end
767776
# return the current knowledge about this cycle
768777
frame = frame::InferenceState
769778
update_valid_age!(frame, caller)
770-
return frame.bestguess, nothing
779+
bestguess = frame.bestguess
780+
if isa(bestguess, Conditional)
781+
bestguess = transform_to_interconditional(bestguess, length(frame.result.argtypes))
782+
end
783+
return bestguess, nothing
784+
end
785+
786+
function transform_to_interconditional(bestguess::Conditional, nargs::Int)
787+
# keep this conditional only when it constrains a slot within call arguments
788+
if 1 < slot_id(bestguess.var) <= nargs
789+
return InterConditional(slot_id(bestguess.var), bestguess.vtype, bestguess.elsetype)
790+
else
791+
return widenconditional(bestguess)
792+
end
771793
end
772794

773795
#### entry points for inferring a MethodInstance given a type signature ####

base/compiler/typelattice.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# structs/constants #
55
#####################
66

7-
# N.B.: Const/PartialStruct are defined in Core, to allow them to be used
7+
# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
88
# inside the global code cache.
99
#
1010
# # The type of a value might be constant
@@ -18,7 +18,6 @@
1818
# end
1919
import Core: Const, PartialStruct
2020

21-
2221
# The type of this value might be Bool.
2322
# However, to enable a limited amount of back-propagagation,
2423
# we also keep some information about how this Bool value was created.
@@ -45,6 +44,15 @@ struct Conditional
4544
end
4645
end
4746

47+
# # similar to `Conditional`, but conveys inter-procedural constrains imposed on call arguments
48+
# struct InterConditional
49+
# slot::Int
50+
# vtype
51+
# elsetype
52+
# end
53+
import Core: InterConditional
54+
const ConditionalWrapper = Union{Conditional,InterConditional}
55+
4856
struct PartialTypeVar
4957
tv::TypeVar
5058
# N.B.: Currently unused, but would allow turning something back
@@ -105,7 +113,7 @@ function issubconditional(a::Conditional, b::Conditional)
105113
end
106114

107115
maybe_extract_const_bool(c::Const) = isa(c.val, Bool) ? c.val : nothing
108-
function maybe_extract_const_bool(c::Conditional)
116+
function maybe_extract_const_bool(c::ConditionalWrapper)
109117
(c.vtype === Bottom && !(c.elsetype === Bottom)) && return false
110118
(c.elsetype === Bottom && !(c.vtype === Bottom)) && return true
111119
nothing
@@ -205,6 +213,7 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
205213
end
206214

207215
widenconst(c::Conditional) = Bool
216+
widenconst(c::InterConditional) = Bool
208217
function widenconst(c::Const)
209218
if isa(c.val, Type)
210219
if isvarargtype(c.val)
@@ -237,7 +246,7 @@ end
237246
@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n, o)))
238247

239248
widenconditional(@nospecialize typ) = typ
240-
function widenconditional(typ::Conditional)
249+
function widenconditional(typ::ConditionalWrapper)
241250
if typ.vtype === Union{}
242251
return Const(false)
243252
elseif typ.elsetype === Union{}

base/compiler/typelimits.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,35 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
327327
end
328328
return Bool
329329
end
330+
# type-lattice for InterConditional wrapper, InterConditional won't be merged with Conditional
331+
if isa(typea, InterConditional) && isa(typeb, Const)
332+
if typeb.val === true
333+
typeb = InterConditional(typea.slot, Any, Union{})
334+
elseif typeb.val === false
335+
typeb = InterConditional(typea.slot, Union{}, Any)
336+
end
337+
end
338+
if isa(typeb, InterConditional) && isa(typea, Const)
339+
if typea.val === true
340+
typea = InterConditional(typeb.slot, Any, Union{})
341+
elseif typea.val === false
342+
typea = InterConditional(typeb.slot, Union{}, Any)
343+
end
344+
end
345+
if isa(typea, InterConditional) && isa(typeb, InterConditional)
346+
if typea.slot === typeb.slot
347+
vtype = tmerge(typea.vtype, typeb.vtype)
348+
elsetype = tmerge(typea.elsetype, typeb.elsetype)
349+
if vtype != elsetype
350+
return InterConditional(typea.slot, vtype, elsetype)
351+
end
352+
end
353+
val = maybe_extract_const_bool(typea)
354+
if val isa Bool && val === maybe_extract_const_bool(typeb)
355+
return Const(val)
356+
end
357+
return Bool
358+
end
330359
if (isa(typea, PartialStruct) || isa(typea, Const)) &&
331360
(isa(typeb, PartialStruct) || isa(typeb, Const)) &&
332361
widenconst(typea) === widenconst(typeb)

base/essentials.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,8 +813,7 @@ const missing = Missing()
813813
814814
Indicate whether `x` is [`missing`](@ref).
815815
"""
816-
ismissing(::Any) = false
817-
ismissing(::Missing) = true
816+
ismissing(x) = x === missing
818817

819818
function popfirst! end
820819

base/some.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ Return `true` if `x === nothing`, and return `false` if not.
6363
!!! compat "Julia 1.1"
6464
This function requires at least Julia 1.1.
6565
"""
66-
isnothing(::Any) = false
67-
isnothing(::Nothing) = true
66+
isnothing(x) = x === nothing
6867

6968

7069
"""

src/builtins.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,6 +1622,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
16221622
add_builtin("Argument", (jl_value_t*)jl_argument_type);
16231623
add_builtin("Const", (jl_value_t*)jl_const_type);
16241624
add_builtin("PartialStruct", (jl_value_t*)jl_partial_struct_type);
1625+
add_builtin("InterConditional", (jl_value_t*)jl_interconditional_type);
16251626
add_builtin("MethodMatch", (jl_value_t*)jl_method_match_type);
16261627
add_builtin("IntrinsicFunction", (jl_value_t*)jl_intrinsic_type);
16271628
add_builtin("Function", (jl_value_t*)jl_function_type);

src/jl_exported_data.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
XX(jl_nothing_type) \
7171
XX(jl_number_type) \
7272
XX(jl_partial_struct_type) \
73+
XX(jl_interconditional_type) \
7374
XX(jl_phicnode_type) \
7475
XX(jl_phinode_type) \
7576
XX(jl_pinode_type) \

src/jltypes.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,6 +2302,10 @@ void jl_init_types(void) JL_GC_DISABLED
23022302
jl_perm_symsvec(2, "typ", "fields"),
23032303
jl_svec2(jl_any_type, jl_array_any_type), 0, 0, 2);
23042304

2305+
jl_interconditional_type = jl_new_datatype(jl_symbol("InterConditional"), core, jl_any_type, jl_emptysvec,
2306+
jl_perm_symsvec(3, "slot", "vtype", "elsetype"),
2307+
jl_svec(3, jl_long_type, jl_any_type, jl_any_type), 0, 0, 3);
2308+
23052309
jl_method_match_type = jl_new_datatype(jl_symbol("MethodMatch"), core, jl_any_type, jl_emptysvec,
23062310
jl_perm_symsvec(4, "spec_types", "sparams", "method", "fully_covers"),
23072311
jl_svec(4, jl_type_type, jl_simplevector_type, jl_method_type, jl_bool_type), 0, 0, 4);

src/julia.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_typedslot_type JL_GLOBALLY_ROOTED;
619619
extern JL_DLLIMPORT jl_datatype_t *jl_argument_type JL_GLOBALLY_ROOTED;
620620
extern JL_DLLIMPORT jl_datatype_t *jl_const_type JL_GLOBALLY_ROOTED;
621621
extern JL_DLLIMPORT jl_datatype_t *jl_partial_struct_type JL_GLOBALLY_ROOTED;
622+
extern JL_DLLIMPORT jl_datatype_t *jl_interconditional_type JL_GLOBALLY_ROOTED;
622623
extern JL_DLLIMPORT jl_datatype_t *jl_method_match_type JL_GLOBALLY_ROOTED;
623624
extern JL_DLLIMPORT jl_datatype_t *jl_simplevector_type JL_GLOBALLY_ROOTED;
624625
extern JL_DLLIMPORT jl_typename_t *jl_tuple_typename JL_GLOBALLY_ROOTED;

src/staticdata.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ jl_value_t **const*const get_tags(void) {
6868
INSERT_TAG(jl_returnnode_type);
6969
INSERT_TAG(jl_const_type);
7070
INSERT_TAG(jl_partial_struct_type);
71+
INSERT_TAG(jl_interconditional_type);
7172
INSERT_TAG(jl_method_match_type);
7273
INSERT_TAG(jl_pinode_type);
7374
INSERT_TAG(jl_phinode_type);

test/compiler/inference.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,6 +1719,41 @@ for expr25261 in opt25261[i:end]
17191719
end
17201720
@test foundslot
17211721

1722+
@testset "interprocedural conditional constraint propagation" begin
1723+
isaint(a) = isa(a, Int)
1724+
@test Base.return_types((Any,)) do a
1725+
isaint(a) && return a # a::Int
1726+
return 0
1727+
end == [Int]
1728+
eqnothing(a) = a === nothing
1729+
@test Base.return_types((Union{Nothing,Int},)) do a
1730+
eqnothing(a) && return 0
1731+
return a # a::Int
1732+
end == [Int]
1733+
1734+
# tests with base functions
1735+
@test Base.return_types((Any,)) do a
1736+
Base.Fix2(isa, Int)(a) && return sin(a) # a::Float64
1737+
return 0.0
1738+
end == [Float64]
1739+
@test Base.return_types((Union{Nothing,Int},)) do a
1740+
isnothing(a) && return 0
1741+
return a # a::Int
1742+
end == [Int]
1743+
1744+
# FIXME: we can't propagate conditional constraints interprocedurally when there're
1745+
# multiple possible conditions within the callee
1746+
ispositive(a) = isa(a, Int) && a > 0
1747+
@test_broken Base.return_types((Any,)) do a
1748+
ispositive(a) && return a # a::Int, ideally
1749+
return 0
1750+
end == [Int]
1751+
@test_broken Base.return_types((Any,)) do x
1752+
Meta.isexpr(x, :call) && return x # x::Expr, ideally
1753+
return nothing
1754+
end == [Nothing,Expr]
1755+
end
1756+
17221757
function f25579(g)
17231758
h = g[]
17241759
t = (h === nothing)

0 commit comments

Comments
 (0)