Skip to content

Commit e9ccda6

Browse files
committed
Merge branch 'backprop2' into moreconstantprop
2 parents 0960479 + 1a347ed commit e9ccda6

File tree

14 files changed

+342
-52
lines changed

14 files changed

+342
-52
lines changed

base/boot.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
424424
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
425425
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
426426
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
427+
eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype))))
427428
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
428429
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))
429430

base/compiler/abstractinterpretation.jl

Lines changed: 185 additions & 27 deletions
Large diffs are not rendered by default.

base/compiler/typeinfer.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,9 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::An
295295
elseif isa(result_type, PartialStruct)
296296
rettype_const = result_type.fields
297297
const_flags = 0x2
298+
elseif isa(result_type, InterConditional)
299+
rettype_const = result_type
300+
const_flags = 0x2
298301
else
299302
rettype_const = nothing
300303
const_flags = 0x00
@@ -770,14 +773,18 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
770773
code = get(code_cache(interp), mi, nothing)
771774
if code isa CodeInstance # return existing rettype if the code is already inferred
772775
update_valid_age!(caller, WorldRange(min_world(code), max_world(code)))
776+
rettype = code.rettype
773777
if isdefined(code, :rettype_const)
774-
if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype)
775-
return PartialStruct(code.rettype, code.rettype_const), mi
778+
rettype_const = code.rettype_const
779+
if isa(rettype_const, InterConditional)
780+
return rettype_const, mi
781+
elseif isa(rettype_const, Vector{Any}) && !(Vector{Any} <: rettype)
782+
return PartialStruct(rettype, rettype_const), mi
776783
else
777-
return Const(code.rettype_const), mi
784+
return Const(rettype_const), mi
778785
end
779786
else
780-
return code.rettype, mi
787+
return rettype, mi
781788
end
782789
end
783790
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0

base/compiler/typelattice.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# structs/constants #
55
#####################
66

7-
# N.B.: Const/PartialStruct are defined in Core, to allow them to be used
7+
# N.B.: Const/PartialStruct/InterConditional are defined in Core, to allow them to be used
88
# inside the global code cache.
99
#
1010
# # The type of a value might be constant
@@ -18,7 +18,6 @@
1818
# end
1919
import Core: Const, PartialStruct
2020

21-
2221
# The type of this value might be Bool.
2322
# However, to enable a limited amount of back-propagagation,
2423
# we also keep some information about how this Bool value was created.
@@ -45,6 +44,18 @@ struct Conditional
4544
end
4645
end
4746

47+
# # Similar to `Conditional`, but conveys inter-procedural constraints imposed on call arguments.
48+
# # This is separate from `Conditional` to catch logic errors: the lattice element name is InterConditional
49+
# # while processing a call, then Conditional everywhere else. Thus InterConditional does not appear in
50+
# # CompilerTypes—these type's usages are disjoint—though we define the lattice for InterConditional.
51+
# struct InterConditional
52+
# slot::Int
53+
# vtype
54+
# elsetype
55+
# end
56+
import Core: InterConditional
57+
const AnyConditional = Union{Conditional,InterConditional}
58+
4859
struct PartialTypeVar
4960
tv::TypeVar
5061
# N.B.: Currently unused, but would allow turning something back
@@ -101,11 +112,10 @@ const CompilerTypes = Union{MaybeUndef, Const, Conditional, NotFound, PartialStr
101112
# lattice logic #
102113
#################
103114

104-
function issubconditional(a::Conditional, b::Conditional)
105-
avar = a.var
106-
bvar = b.var
107-
if (isa(avar, Slot) && isa(bvar, Slot) && slot_id(avar) === slot_id(bvar)) ||
108-
(isa(avar, SSAValue) && isa(bvar, SSAValue) && avar === bvar)
115+
# `Conditional` and `InterConditional` are valid in opposite contexts
116+
# (i.e. local inference and inter-procedural call), as such they will never be compared
117+
function issubconditional(a::C, b::C) where {C<:AnyConditional}
118+
if is_same_conditionals(a, b)
109119
if a.vtype b.vtype
110120
if a.elsetype b.elsetype
111121
return true
@@ -115,8 +125,13 @@ function issubconditional(a::Conditional, b::Conditional)
115125
return false
116126
end
117127

118-
maybe_extract_const_bool(c::Const) = isa(c.val, Bool) ? c.val : nothing
119-
function maybe_extract_const_bool(c::Conditional)
128+
is_same_conditionals(a::Conditional, b::Conditional) = slot_id(a.var) === slot_id(b.var)
129+
is_same_conditionals(a::InterConditional, b::InterConditional) = a.slot === b.slot
130+
131+
is_lattice_bool(@nospecialize(typ)) = typ !== Bottom && typ Bool
132+
133+
maybe_extract_const_bool(c::Const) = (val = c.val; isa(val, Bool)) ? val : nothing
134+
function maybe_extract_const_bool(c::AnyConditional)
120135
(c.vtype === Bottom && !(c.elsetype === Bottom)) && return false
121136
(c.elsetype === Bottom && !(c.vtype === Bottom)) && return true
122137
nothing
@@ -145,14 +160,14 @@ function ⊑(@nospecialize(a), @nospecialize(b))
145160
b === Union{} && return false
146161
@assert !isa(a, TypeVar) "invalid lattice item"
147162
@assert !isa(b, TypeVar) "invalid lattice item"
148-
if isa(a, Conditional)
149-
if isa(b, Conditional)
163+
if isa(a, AnyConditional)
164+
if isa(b, AnyConditional)
150165
return issubconditional(a, b)
151166
elseif isa(b, Const) && isa(b.val, Bool)
152167
return maybe_extract_const_bool(a) === b.val
153168
end
154169
a = Bool
155-
elseif isa(b, Conditional)
170+
elseif isa(b, AnyConditional)
156171
return false
157172
end
158173
if isa(a, PartialStruct)
@@ -226,7 +241,7 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
226241
return a b && b a
227242
end
228243

229-
widenconst(c::Conditional) = Bool
244+
widenconst(c::AnyConditional) = Bool
230245
function widenconst(c::Const)
231246
if isa(c.val, Type)
232247
if isvarargtype(c.val)
@@ -260,7 +275,7 @@ end
260275
@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n, o)))
261276

262277
widenconditional(@nospecialize typ) = typ
263-
function widenconditional(typ::Conditional)
278+
function widenconditional(typ::AnyConditional)
264279
if typ.vtype === Union{}
265280
return Const(false)
266281
elseif typ.elsetype === Union{}

base/compiler/typelimits.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
334334
end
335335
end
336336
if isa(typea, Conditional) && isa(typeb, Conditional)
337-
if typea.var === typeb.var
337+
if is_same_conditionals(typea, typeb)
338338
vtype = tmerge(typea.vtype, typeb.vtype)
339339
elsetype = tmerge(typea.elsetype, typeb.elsetype)
340340
if vtype != elsetype
@@ -347,6 +347,36 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
347347
end
348348
return Bool
349349
end
350+
# type-lattice for InterConditional wrapper, InterConditional will never be merged with Conditional
351+
if isa(typea, InterConditional) && isa(typeb, Const)
352+
if typeb.val === true
353+
typeb = InterConditional(typea.slot, Any, Union{})
354+
elseif typeb.val === false
355+
typeb = InterConditional(typea.slot, Union{}, Any)
356+
end
357+
end
358+
if isa(typeb, InterConditional) && isa(typea, Const)
359+
if typea.val === true
360+
typea = InterConditional(typeb.slot, Any, Union{})
361+
elseif typea.val === false
362+
typea = InterConditional(typeb.slot, Union{}, Any)
363+
end
364+
end
365+
if isa(typea, InterConditional) && isa(typeb, InterConditional)
366+
if is_same_conditionals(typea, typeb)
367+
vtype = tmerge(typea.vtype, typeb.vtype)
368+
elsetype = tmerge(typea.elsetype, typeb.elsetype)
369+
if vtype != elsetype
370+
return InterConditional(typea.slot, vtype, elsetype)
371+
end
372+
end
373+
val = maybe_extract_const_bool(typea)
374+
if val isa Bool && val === maybe_extract_const_bool(typeb)
375+
return Const(val)
376+
end
377+
return Bool
378+
end
379+
# type-lattice for Const and PartialStruct wrappers
350380
if (isa(typea, PartialStruct) || isa(typea, Const)) &&
351381
(isa(typeb, PartialStruct) || isa(typeb, Const)) &&
352382
widenconst(typea) === widenconst(typeb)

base/essentials.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,8 +813,7 @@ const missing = Missing()
813813
814814
Indicate whether `x` is [`missing`](@ref).
815815
"""
816-
ismissing(::Any) = false
817-
ismissing(::Missing) = true
816+
ismissing(x) = x === missing
818817

819818
function popfirst! end
820819

base/show.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1380,7 +1380,7 @@ function operator_associativity(s::Symbol)
13801380
end
13811381

13821382
is_expr(@nospecialize(ex), head::Symbol) = isa(ex, Expr) && (ex.head === head)
1383-
is_expr(@nospecialize(ex), head::Symbol, n::Int) = is_expr(ex, head) && length((ex::Expr).args) == n
1383+
is_expr(@nospecialize(ex), head::Symbol, n::Int) = is_expr(ex, head) && length(ex.args) == n
13841384

13851385
is_quoted(ex) = false
13861386
is_quoted(ex::QuoteNode) = true

base/some.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ Return `true` if `x === nothing`, and return `false` if not.
6363
!!! compat "Julia 1.1"
6464
This function requires at least Julia 1.1.
6565
"""
66-
isnothing(::Any) = false
67-
isnothing(::Nothing) = true
66+
isnothing(x) = x === nothing
6867

6968

7069
"""

src/builtins.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,6 +1611,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
16111611
add_builtin("Argument", (jl_value_t*)jl_argument_type);
16121612
add_builtin("Const", (jl_value_t*)jl_const_type);
16131613
add_builtin("PartialStruct", (jl_value_t*)jl_partial_struct_type);
1614+
add_builtin("InterConditional", (jl_value_t*)jl_interconditional_type);
16141615
add_builtin("MethodMatch", (jl_value_t*)jl_method_match_type);
16151616
add_builtin("IntrinsicFunction", (jl_value_t*)jl_intrinsic_type);
16161617
add_builtin("Function", (jl_value_t*)jl_function_type);

src/jl_exported_data.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
XX(jl_nothing_type) \
7272
XX(jl_number_type) \
7373
XX(jl_partial_struct_type) \
74+
XX(jl_interconditional_type) \
7475
XX(jl_phicnode_type) \
7576
XX(jl_phinode_type) \
7677
XX(jl_pinode_type) \

src/jltypes.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2308,6 +2308,10 @@ void jl_init_types(void) JL_GC_DISABLED
23082308
jl_perm_symsvec(2, "typ", "fields"),
23092309
jl_svec2(jl_any_type, jl_array_any_type), 0, 0, 2);
23102310

2311+
jl_interconditional_type = jl_new_datatype(jl_symbol("InterConditional"), core, jl_any_type, jl_emptysvec,
2312+
jl_perm_symsvec(3, "slot", "vtype", "elsetype"),
2313+
jl_svec(3, jl_long_type, jl_any_type, jl_any_type), 0, 0, 3);
2314+
23112315
jl_method_match_type = jl_new_datatype(jl_symbol("MethodMatch"), core, jl_any_type, jl_emptysvec,
23122316
jl_perm_symsvec(4, "spec_types", "sparams", "method", "fully_covers"),
23132317
jl_svec(4, jl_type_type, jl_simplevector_type, jl_method_type, jl_bool_type), 0, 0, 4);

src/julia.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,7 @@ extern JL_DLLIMPORT jl_datatype_t *jl_typedslot_type JL_GLOBALLY_ROOTED;
633633
extern JL_DLLIMPORT jl_datatype_t *jl_argument_type JL_GLOBALLY_ROOTED;
634634
extern JL_DLLIMPORT jl_datatype_t *jl_const_type JL_GLOBALLY_ROOTED;
635635
extern JL_DLLIMPORT jl_datatype_t *jl_partial_struct_type JL_GLOBALLY_ROOTED;
636+
extern JL_DLLIMPORT jl_datatype_t *jl_interconditional_type JL_GLOBALLY_ROOTED;
636637
extern JL_DLLIMPORT jl_datatype_t *jl_method_match_type JL_GLOBALLY_ROOTED;
637638
extern JL_DLLIMPORT jl_datatype_t *jl_simplevector_type JL_GLOBALLY_ROOTED;
638639
extern JL_DLLIMPORT jl_typename_t *jl_tuple_typename JL_GLOBALLY_ROOTED;

src/staticdata.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ jl_value_t **const*const get_tags(void) {
6868
INSERT_TAG(jl_returnnode_type);
6969
INSERT_TAG(jl_const_type);
7070
INSERT_TAG(jl_partial_struct_type);
71+
INSERT_TAG(jl_interconditional_type);
7172
INSERT_TAG(jl_method_match_type);
7273
INSERT_TAG(jl_pinode_type);
7374
INSERT_TAG(jl_phinode_type);

test/compiler/inference.jl

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ tmerge_test(Tuple{}, Tuple{Complex, Vararg{Union{ComplexF32, ComplexF64}}},
109109
@test Core.Compiler.tmerge(Vector{Int}, Core.Compiler.tmerge(Vector{String}, Vector{Bool})) == Vector
110110
@test Core.Compiler.tmerge(Base.BitIntegerType, Union{}) === Base.BitIntegerType
111111
@test Core.Compiler.tmerge(Union{}, Base.BitIntegerType) === Base.BitIntegerType
112+
@test Core.Compiler.tmerge(Core.Compiler.InterConditional(1, Int, Union{}), Core.Compiler.InterConditional(2, String, Union{})) === Core.Compiler.Const(true)
112113

113114
struct SomethingBits
114115
x::Base.BitIntegerType
@@ -1255,7 +1256,8 @@ end
12551256
push!(constvec, 10)
12561257
@test @inferred(sizeof_constvec()) == sizeof(Int) * 4
12571258

1258-
test_const_return((x)->isdefined(x, :re), Tuple{ComplexF64}, true)
1259+
test_const_return(x->isdefined(x, :re), Tuple{ComplexF64}, true)
1260+
12591261
isdefined_f3(x) = isdefined(x, 3)
12601262
@test @inferred(isdefined_f3(())) == false
12611263
@test find_call(first(code_typed(isdefined_f3, Tuple{Tuple{Vararg{Int}}})[1]), isdefined, 3)
@@ -1722,6 +1724,77 @@ for expr25261 in opt25261[i:end]
17221724
end
17231725
@test foundslot
17241726

1727+
@testset "inter-procedural conditional constraint propagation" begin
1728+
# simple cases
1729+
isaint(a) = isa(a, Int)
1730+
@test Base.return_types((Any,)) do a
1731+
isaint(a) && return a # a::Int
1732+
return 0
1733+
end == Any[Int]
1734+
eqnothing(a) = a === nothing
1735+
@test Base.return_types((Union{Nothing,Int},)) do a
1736+
eqnothing(a) && return 0
1737+
return a # a::Int
1738+
end == Any[Int]
1739+
1740+
# more complicated cases
1741+
ispositive(a) = isa(a, Int) && a > 0
1742+
@test Base.return_types((Any,)) do a
1743+
ispositive(a) && return a # a::Int
1744+
return 0
1745+
end == Any[Int]
1746+
global isaint2
1747+
isaint2(a::Int) = true
1748+
isaint2(@nospecialize(_)) = false
1749+
@test Base.return_types((Any,)) do a
1750+
isaint2(a) && return a # a::Int
1751+
return 0
1752+
end == Any[Int]
1753+
global ispositive2
1754+
ispositive2(a::Int) = a > 0
1755+
ispositive2(@nospecialize(_)) = false
1756+
@test Base.return_types((Any,)) do a
1757+
ispositive2(a) && return a # a::Int
1758+
return 0
1759+
end == Any[Int]
1760+
1761+
# type constraints from multiple constant boolean return types
1762+
function f(x)
1763+
isa(x, Int) && return true
1764+
isa(x, Symbol) && return true
1765+
return false
1766+
end
1767+
@test Base.return_types((Any,)) do x
1768+
f(x) && return x # x::Union{Int,Symbol}
1769+
return nothing
1770+
end == Any[Union{Int,Symbol,Nothing}]
1771+
1772+
# constraint on non-vararg argument of `isva` method
1773+
isaint_isvapositive(a, va...) = isa(a, Int) && sum(va) > 0
1774+
@test Base.return_types((Any,Int,Int)) do a, b, c
1775+
isaint_isvapositive(a, b, c) && return a # a::Int
1776+
0
1777+
end == Any[Int]
1778+
1779+
# with Base functions
1780+
@test Base.return_types((Any,)) do a
1781+
Base.Fix2(isa, Int)(a) && return a # a::Int
1782+
return 0
1783+
end == Any[Int]
1784+
@test Base.return_types((Union{Nothing,Int},)) do a
1785+
isnothing(a) && return 0
1786+
return a # a::Int
1787+
end == Any[Int]
1788+
@test Base.return_types((Union{Missing,Int},)) do a
1789+
ismissing(a) && return 0
1790+
return a # a::Int
1791+
end == Any[Int]
1792+
@test Base.return_types((Any,)) do x
1793+
Meta.isexpr(x, :call) && return x # x::Expr
1794+
return nothing
1795+
end == Any[Union{Nothing,Expr}]
1796+
end
1797+
17251798
function f25579(g)
17261799
h = g[]
17271800
t = (h === nothing)

0 commit comments

Comments
 (0)