Skip to content

Commit 9d2e9ed

Browse files
authored
Support adding CodeInstances to JIT for interpreters defining a codegen cache (#57272)
Implements a way to add `CodeInstance`s compiled by external interpreters to JIT, such that they become legal targets for `invoke` calls. Based on a design proposed by @Keno, the `AbstractInterpreter` interface is extended to support providing a codegen cache that is filled during inference for future use with `add_codeinsts_to_jit!`. This allows `invoke(f, ::CodeInstance, args...)` to work on external interpreters, which is currently failing on `master` (see #57193). --------- Co-authored-by: Cédric Belmant <[email protected]>
1 parent 7c89aba commit 9d2e9ed

File tree

4 files changed

+86
-37
lines changed

4 files changed

+86
-37
lines changed

Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ struct SplitCacheInterp <: Compiler.AbstractInterpreter
99
inf_params::Compiler.InferenceParams
1010
opt_params::Compiler.OptimizationParams
1111
inf_cache::Vector{Compiler.InferenceResult}
12+
codegen_cache::IdDict{CodeInstance,CodeInfo}
1213
function SplitCacheInterp(;
1314
world::UInt = Base.get_world_counter(),
1415
inf_params::Compiler.InferenceParams = Compiler.InferenceParams(),
1516
opt_params::Compiler.OptimizationParams = Compiler.OptimizationParams(),
1617
inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[])
17-
new(world, inf_params, opt_params, inf_cache)
18+
new(world, inf_params, opt_params, inf_cache, IdDict{CodeInstance,CodeInfo}())
1819
end
1920
end
2021

@@ -23,10 +24,11 @@ Compiler.OptimizationParams(interp::SplitCacheInterp) = interp.opt_params
2324
Compiler.get_inference_world(interp::SplitCacheInterp) = interp.world
2425
Compiler.get_inference_cache(interp::SplitCacheInterp) = interp.inf_cache
2526
Compiler.cache_owner(::SplitCacheInterp) = SplitCacheOwner()
27+
Compiler.codegen_cache(interp::SplitCacheInterp) = interp.codegen_cache
2628

2729
import Core.OptimizedGenerics.CompilerPlugins: typeinf, typeinf_edge
2830
@eval @noinline typeinf(::SplitCacheOwner, mi::MethodInstance, source_mode::UInt8) =
29-
Base.invoke_in_world(which(typeinf, Tuple{SplitCacheOwner, MethodInstance, UInt8}).primary_world, Compiler.typeinf_ext, SplitCacheInterp(; world=Base.tls_world_age()), mi, source_mode)
31+
Base.invoke_in_world(which(typeinf, Tuple{SplitCacheOwner, MethodInstance, UInt8}).primary_world, Compiler.typeinf_ext_toplevel, SplitCacheInterp(; world=Base.tls_world_age()), mi, source_mode)
3032

3133
@eval @noinline function typeinf_edge(::SplitCacheOwner, mi::MethodInstance, parent_frame::Compiler.InferenceState, world::UInt, source_mode::UInt8)
3234
# TODO: This isn't quite right, we're just sketching things for now

Compiler/src/typeinfer.jl

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,10 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
143143
ci, inferred_result, const_flag, first(result.valid_worlds), last(result.valid_worlds), encode_effects(result.ipo_effects),
144144
result.analysis_results, di, edges)
145145
engine_reject(interp, ci)
146-
if !discard_src && isdefined(interp, :codegen) && uncompressed isa CodeInfo
146+
codegen = codegen_cache(interp)
147+
if !discard_src && codegen !== nothing && uncompressed isa CodeInfo
147148
# record that the caller could use this result to generate code when required, if desired, to avoid repeating n^2 work
148-
interp.codegen[ci] = uncompressed
149+
codegen[ci] = uncompressed
149150
if bootstrapping_compiler && inferred_result == nothing
150151
# This is necessary to get decent bootstrapping performance
151152
# when compiling the compiler to inject everything eagerly
@@ -185,8 +186,9 @@ function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstan
185186
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
186187
ci, nothing, const_flag, min_world, max_world, ipo_effects, nothing, di, edges)
187188
code_cache(interp)[mi] = ci
188-
if isdefined(interp, :codegen)
189-
interp.codegen[ci] = src
189+
codegen = codegen_cache(interp)
190+
if codegen !== nothing
191+
codegen[ci] = src
190192
end
191193
engine_reject(interp, ci)
192194
return nothing
@@ -1168,7 +1170,10 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance, source_mod
11681170

11691171
ci = result.ci # reload from result in case it changed
11701172
@assert frame.cache_mode != CACHE_MODE_NULL
1171-
@assert is_result_constabi_eligible(result) || (!isdefined(interp, :codegen) || haskey(interp.codegen, ci))
1173+
@assert is_result_constabi_eligible(result) || begin
1174+
codegen = codegen_cache(interp)
1175+
codegen === nothing || haskey(codegen, ci)
1176+
end
11721177
@assert is_result_constabi_eligible(result) == use_const_api(ci)
11731178
@assert isdefined(ci, :inferred) "interpreter did not fulfill our expectations"
11741179
if !is_cached(frame) && source_mode == SOURCE_MODE_ABI
@@ -1234,44 +1239,55 @@ function collectinvokes!(wq::Vector{CodeInstance}, ci::CodeInfo)
12341239
end
12351240
end
12361241

1237-
# This is a bridge for the C code calling `jl_typeinf_func()` on a single Method match
1238-
function typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8)
1239-
interp = NativeInterpreter(world)
1240-
ci = typeinf_ext(interp, mi, source_mode)
1241-
if source_mode == SOURCE_MODE_ABI && ci isa CodeInstance && !ci_has_invoke(ci)
1242-
inspected = IdSet{CodeInstance}()
1243-
tocompile = Vector{CodeInstance}()
1244-
push!(tocompile, ci)
1245-
while !isempty(tocompile)
1246-
# ci_has_real_invoke(ci) && return ci # optimization: cease looping if ci happens to get compiled (not just jl_fptr_wait_for_compiled, but fully jl_is_compiled_codeinst)
1247-
callee = pop!(tocompile)
1248-
ci_has_invoke(callee) && continue
1249-
callee in inspected && continue
1250-
src = get(interp.codegen, callee, nothing)
1242+
function add_codeinsts_to_jit!(interp::AbstractInterpreter, ci, source_mode::UInt8)
1243+
source_mode == SOURCE_MODE_ABI || return ci
1244+
ci isa CodeInstance && !ci_has_invoke(ci) || return ci
1245+
codegen = codegen_cache(interp)
1246+
codegen === nothing && return ci
1247+
inspected = IdSet{CodeInstance}()
1248+
tocompile = Vector{CodeInstance}()
1249+
push!(tocompile, ci)
1250+
while !isempty(tocompile)
1251+
# ci_has_real_invoke(ci) && return ci # optimization: cease looping if ci happens to get compiled (not just jl_fptr_wait_for_compiled, but fully jl_is_compiled_codeinst)
1252+
callee = pop!(tocompile)
1253+
ci_has_invoke(callee) && continue
1254+
callee in inspected && continue
1255+
src = get(codegen, callee, nothing)
1256+
if !isa(src, CodeInfo)
1257+
src = @atomic :monotonic callee.inferred
1258+
if isa(src, String)
1259+
src = _uncompressed_ir(callee, src)
1260+
end
12511261
if !isa(src, CodeInfo)
1252-
src = @atomic :monotonic callee.inferred
1253-
if isa(src, String)
1254-
src = _uncompressed_ir(callee, src)
1255-
end
1256-
if !isa(src, CodeInfo)
1257-
newcallee = typeinf_ext(interp, callee.def, source_mode)
1258-
if newcallee isa CodeInstance
1259-
callee === ci && (ci = newcallee) # ci stopped meeting the requirements after typeinf_ext last checked, try again with newcallee
1260-
push!(tocompile, newcallee)
1261-
#else
1262-
# println("warning: could not get source code for ", callee.def)
1263-
end
1264-
continue
1262+
newcallee = typeinf_ext(interp, callee.def, source_mode)
1263+
if newcallee isa CodeInstance
1264+
callee === ci && (ci = newcallee) # ci stopped meeting the requirements after typeinf_ext last checked, try again with newcallee
1265+
push!(tocompile, newcallee)
1266+
#else
1267+
# println("warning: could not get source code for ", callee.def)
12651268
end
1269+
continue
12661270
end
1267-
push!(inspected, callee)
1268-
collectinvokes!(tocompile, src)
1269-
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), callee, src)
12701271
end
1272+
push!(inspected, callee)
1273+
collectinvokes!(tocompile, src)
1274+
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), callee, src)
12711275
end
12721276
return ci
12731277
end
12741278

1279+
function typeinf_ext_toplevel(interp::AbstractInterpreter, mi::MethodInstance, source_mode::UInt8)
1280+
ci = typeinf_ext(interp, mi, source_mode)
1281+
ci = add_codeinsts_to_jit!(interp, ci, source_mode)
1282+
return ci
1283+
end
1284+
1285+
# This is a bridge for the C code calling `jl_typeinf_func()` on a single Method match
1286+
function typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8)
1287+
interp = NativeInterpreter(world)
1288+
return typeinf_ext_toplevel(interp, mi, source_mode)
1289+
end
1290+
12751291
# This is a bridge for the C code calling `jl_typeinf_func()` on set of Method matches
12761292
function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::Bool)
12771293
inspected = IdSet{CodeInstance}()

Compiler/src/types.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ the following methods to satisfy the `AbstractInterpreter` API requirement:
2323
- `get_inference_world(interp::NewInterpreter)` - return the world age for this interpreter
2424
- `get_inference_cache(interp::NewInterpreter)` - return the local inference cache
2525
- `cache_owner(interp::NewInterpreter)` - return the owner of any new cache entries
26+
27+
If `CodeInstance`s compiled using `interp::NewInterpreter` are meant to be executed with `invoke`,
28+
a method `codegen_cache(interp::NewInterpreter) -> IdDict{CodeInstance, CodeInfo}` must be defined,
29+
and inference must be triggered via `typeinf_ext_toplevel` with source mode `SOURCE_MODE_ABI`.
2630
"""
2731
abstract type AbstractInterpreter end
2832

@@ -430,6 +434,19 @@ to incorporate customized dispatches for the overridden methods.
430434
method_table(interp::AbstractInterpreter) = InternalMethodTable(get_inference_world(interp))
431435
method_table(interp::NativeInterpreter) = interp.method_table
432436

437+
"""
438+
codegen_cache(interp::AbstractInterpreter) -> Union{Nothing, IdDict{CodeInstance, CodeInfo}}
439+
440+
Optionally return a cache associating a `CodeInfo` to a `CodeInstance` that should be added to the JIT
441+
for future execution via `invoke(f, ::CodeInstance, args...)`. This cache is used during `typeinf_ext_toplevel`,
442+
and may be safely discarded between calls to this function.
443+
444+
By default, a value of `nothing` is returned indicating that `CodeInstance`s should not be added to the JIT.
445+
Attempting to execute them via `invoke` will result in an error.
446+
"""
447+
codegen_cache(interp::AbstractInterpreter) = nothing
448+
codegen_cache(interp::NativeInterpreter) = interp.codegen
449+
433450
"""
434451
By default `AbstractInterpreter` implements the following inference bail out logic:
435452
- `bail_out_toplevel_call(::AbstractInterpreter, sig, ::InferenceState)`: bail out from

Compiler/test/AbstractInterpreter.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,3 +534,17 @@ let interp = DebugInterp()
534534
end
535535
@test found
536536
end
537+
538+
@newinterp InvokeInterp
539+
struct InvokeOwner end
540+
codegen = IdDict{CodeInstance, CodeInfo}()
541+
Compiler.cache_owner(::InvokeInterp) = InvokeOwner()
542+
Compiler.codegen_cache(::InvokeInterp) = codegen
543+
let interp = InvokeInterp()
544+
source_mode = Compiler.SOURCE_MODE_ABI
545+
f = (+)
546+
args = (1, 1)
547+
mi = @ccall jl_method_lookup(Any[f, args...]::Ptr{Any}, (1+length(args))::Csize_t, Base.tls_world_age()::Csize_t)::Ref{Core.MethodInstance}
548+
ci = Compiler.typeinf_ext_toplevel(interp, mi, source_mode)
549+
@test invoke(f, ci, args...) == 2
550+
end

0 commit comments

Comments
 (0)