Skip to content

infer binding types by use #234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 12, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/StaticLint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,12 @@ function (state::Delayed)(x::EXPR)
resolve_ref(x, state)

traverse(x, state)

state.scope != s0 && (state.scope = s0)
if state.scope != s0
for (k,b) in state.scope.names
infer_type_by_use(b, state.server)
end
(state.scope = s0)
end
return state.scope
end

Expand Down
5 changes: 5 additions & 0 deletions src/linting/checks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,11 @@ function should_mark_missing_getfield_ref(x, server)
if lhsref.type isa SymbolServer.DataTypeStore && !(isempty(lhsref.type.fieldnames) || isunionfaketype(lhsref.type.name) || has_getproperty_method(lhsref.type, server))
return true
elseif lhsref.type isa Binding && lhsref.type.val isa EXPR && CSTParser.defines_struct(lhsref.type.val) && !has_getproperty_method(lhsref.type)
# We may have infered the lhs type after the semantic pass that was resolving references. Copied from `resolve_getfield(x::EXPR, parent_type::EXPR, state::State)::Bool`.
if scopehasbinding(scopeof(lhsref.type.val), valof(x))
setref!(x, scopeof(lhsref.type.val).names[valof(x)])
return false
end
return true
end
end
Expand Down
3 changes: 3 additions & 0 deletions src/server.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ function semantic_pass(file, target=nothing)
for x in state.delayed
if hasscope(x)
traverse(x, Delayed(scopeof(x), server))
for (k,b) in scopeof(x).names
infer_type_by_use(b, state.server)
end
else
ds = retrieve_delayed_scope(x)
traverse(x, Delayed(ds, server))
Expand Down
114 changes: 114 additions & 0 deletions src/type_inf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,117 @@ function infer_type(binding::Binding, scope, state)
end
end
end

# Work out what type a bound variable has by functions that are called on it.
function infer_type_by_use(b::Binding, server)
b.type !== nothing && return # b already has a type
possibletypes = []
visitedmethods = []
for ref in b.refs
new_possibles = []
ref isa EXPR || continue # skip non-EXPR (i.e. used for handling of globals)
check_ref_against_calls(ref, visitedmethods, new_possibles, server)

if isempty(possibletypes)
possibletypes = new_possibles
elseif !isempty(new_possibles)
possibletypes = intersect(possibletypes, new_possibles)
if isempty(possibletypes)
return
end
end
end
# Only do something if we're left with a singleton set at the end.
if length(possibletypes) == 1
type = first(possibletypes)

if type isa Binding
b.type = type
elseif type isa SymbolServer.DataTypeStore
b.type = type
elseif type isa SymbolServer.VarRef
b.type = SymbolServer._lookup(type, getsymbolserver(server)) # could be nothing
elseif type isa SymbolServer.FakeTypeName && isempty(type.parameters)
b.type = SymbolServer._lookup(type.name, getsymbolserver(server)) # could be nothing
end
end
end

function check_ref_against_calls(x, visitedmethods, new_possibles, server)
if is_arg_of_resolved_call(x)
sig = parentof(x)
# x is argument of function call (func) and we know what that function is
if CSTParser.isidentifier(sig.args[1])
func = refof(sig.args[1])
else
func = refof(sig.args[1].args[2].args[1])
end
# make sure we've got the last binding for func
if func isa Binding
func = get_last_method(func, server)
end
# what slot does ref sit in?
argi = get_arg_position_in_call(sig, x)
tls = retrieve_toplevel_scope(x)
while (func isa Binding && func.type == CoreTypes.Function) || func isa SymbolServer.SymStore
!(func in visitedmethods) ? push!(visitedmethods, func) : return # check whether we've been here before
if func isa Binding
get_arg_type_at_position(func, argi, new_possibles)
func = func.prev
else
tls === nothing && return
iterate_over_ss_methods(func, tls, server, m->(get_arg_type_at_position(m, argi, new_possibles);false))
return
end
end
end
end

function is_arg_of_resolved_call(x::EXPR)
parentof(x) isa EXPR && headof(parentof(x)) === :call && # check we're in a call signature
(caller = parentof(x).args[1]) !== x && # and that x is not the caller
(hasref(caller) || (is_getfield(caller) && headof(caller.args[2]) === :quotenode && hasref(caller.args[2].args[1])))
end

function get_arg_position_in_call(sig::EXPR, arg)
for i in 1:length(sig.args)
sig.args[i] == arg && return i
end
end

function get_arg_type_at_position(b::Binding, argi, types)
if b.val isa EXPR
sig = CSTParser.get_sig(b.val)
if sig !== nothing &&
sig.args !== nothing && argi <= length(sig.args) &&
hasbinding(sig.args[argi]) &&
(argb = bindingof(sig.args[argi]); argb isa Binding && argb.type !== nothing) &&
!(argb.type in types)
push!(types, argb.type)
return
end
elseif b.val isa SymbolServer.DataTypeStore || b.val isa SymbolServer.FunctionStore
for m in b.val.methods
get_arg_type_at_position(m, argi, types)
end
end
return
end

function get_arg_type_at_position(m::SymbolServer.MethodStore, argi, types)
if length(m.sig) >= argi && m.sig[argi][2] != SymbolServer.VarRef(SymbolServer.VarRef(nothing, :Core), :Any) && !(m.sig[argi][2] in types)
push!(types, m.sig[argi][2])
end
end

function get_last_method(b::Binding, server, visited_bindings = Binding[])
if b.next === nothing || b == b.next || !(b.next isa Binding) || b in visited_bindings
return b
end
push!(visited_bindings, b)
if b.type == b.next.type == CoreTypes.Function
return get_last_method(b.next, server, visited_bindings)
else
return b
end
end
43 changes: 43 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1417,4 +1417,47 @@ f(arg) = arg
@test isempty(StaticLint.collect_hints(cst, server))
end
end
@testset "type inference by use" begin
cst = parse_and_pass("""
f(x::String) = true
function g(x)
f(x)
end""")
@test bindingof(cst.args[2].args[1].args[2]).type !== nothing

cst = parse_and_pass("""
f(x::String) = true
f(x::Char) = true
function g(x)
f(x)
end""")
@test bindingof(cst.args[3].args[1].args[2]).type === nothing

cst = parse_and_pass("""
f(x::String) = true
f1(x::String) = true
function g(x)
f(x)
f1(x)
end""")
@test bindingof(cst.args[3].args[1].args[2]).type !== nothing

cst = parse_and_pass("""
f(x::String) = true
f1(x::Char) = true
function g(x)
f(x)
f1(x)
end""")
@test bindingof(cst.args[3].args[1].args[2]).type === nothing

cst = parse_and_pass("""
f(x::String) = true
f1(x) = true
function g(x)
f(x)
f1(x)
end""")
@test bindingof(cst.args[3].args[1].args[2]).type !== nothing
end
end