Skip to content

inference: implement an opt-in interface to cache generated sources #54916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 28, 2024

Conversation

aviatesk
Copy link
Member

In Cassette-like systems, where inference has to infer many calls of @generated function and the generated function involves complex code transformations, the overhead from code generation itself can become significant. This is because the results of code generation are not cached, leading to duplicated code generation in the following contexts:

  • method_for_inference_heuristics for regular inference on cached @generated function calls (since method_for_inference_limit_heuristics isn't stored in cached optimized sources, but is attached to generated unoptimized sources).
  • retrieval_code_info for constant propagation on cached @generated function calls.

Having said that, caching unoptimized sources generated by @generated functions is not a good tradeoff in general cases, considering the memory space consumed (and the image bloat). The code generation for generators like GeneratedFunctionStub produced by the front end is generally very simple, and the first duplicated code generation mentioned above does not occur for GeneratedFunctionStub.

So this unoptimized source caching should be enabled in an opt-in manner.

Based on this idea, this commit defines the trait abstract type Core.CachedGenerator as an interface for the external systems to opt-in. If the generator is a subtype of this trait, inference caches the generated unoptimized code, sacrificing memory space to improve the performance of subsequent inferences. Specifically, the mechanism for caching the unoptimized source uses the infrastructure already implemented in #54362. Thanks to #54362, the cache for generated functions is now partitioned by world age, so even if the unoptimized source is cached, the existing invalidation system will invalidate it as expected.

In JuliaDebug/CassetteOverlay.jl#56, the following benchmark results showed that approximately 1.5~3x inference speedup is achieved by opting into this feature:

Setup

using CassetteOverlay, BaseBenchmarks, BenchmarkTools

@MethodTable table;
pass = @overlaypass table;
BaseBenchmarks.load!("inference");
benchfunc1() = sin(42)
benchfunc2(xs, x) = findall(>(x), abs.(xs))
interp = BaseBenchmarks.InferenceBenchmarks.InferenceBenchmarker()

# benchmark inference on entire call graphs from scratch
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)

# benchmark inference on the call graphs with most of them cached
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)

Benchmark inference on entire call graphs from scratch

on master

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
BenchmarkTools.Trial: 61 samples with 1 evaluation.
 Range (min … max):  78.574 ms … 87.653 ms  ┊ GC (min … max): 0.00% … 8.81%
 Time  (median):     83.149 ms              ┊ GC (median):    4.85%
 Time  (mean ± σ):   82.138 ms ±  2.366 ms  ┊ GC (mean ± σ):  3.36% ± 2.65%

  ▂ ▂▂ █     ▂                     █ ▅     ▅
  █▅██▅█▅▁█▁▁█▁▁▁▁▅▁▁▁▁▁▁▁▁▅▁▁▅██▅▅█████████▁█▁▅▁▁▁▁▁▁▁▁▁▁▁▁▅ ▁
  78.6 ms         Histogram: frequency by time        86.8 ms <

 Memory estimate: 52.32 MiB, allocs estimate: 1201192.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.345 s …  1.369 s  ┊ GC (min … max): 2.45% … 3.39%
 Time  (median):     1.355 s             ┊ GC (median):    2.98%
 Time  (mean ± σ):   1.356 s ± 9.847 ms  ┊ GC (mean ± σ):  2.96% ± 0.41%

  █                   █     █                            █
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.35 s        Histogram: frequency by time        1.37 s <

 Memory estimate: 637.96 MiB, allocs estimate: 15159639.

with this PR

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
BenchmarkTools.Trial: 230 samples with 1 evaluation.
 Range (min … max):  19.339 ms … 82.521 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     19.938 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.665 ms ±  4.666 ms  ┊ GC (mean ± σ):  6.72% ± 8.80%

  ▃▇█▇▄                     ▂▂▃▃▄
  █████▇█▇▆▅▅▆▅▅▁▅▁▁▁▁▁▁▁▁▁██████▆▁█▁▅▇▆▁▅▁▁▅▁▅▁▁▁▁▁▁▅▁▁▁▁▁▁▅ ▆
  19.3 ms      Histogram: log(frequency) by time      29.4 ms <

 Memory estimate: 28.67 MiB, allocs estimate: 590138.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):  354.585 ms … 390.400 ms  ┊ GC (min … max): 0.00% … 7.01%
 Time  (median):     368.778 ms               ┊ GC (median):    3.74%
 Time  (mean ± σ):   368.824 ms ±   8.853 ms  ┊ GC (mean ± σ):  3.70% ± 1.89%

             ▃            █
  ▇▁▁▁▁▁▁▁▁▁▁█▁▇▇▁▁▁▁▇▁▁▁▁█▁▁▁▁▇▁▁▇▁▁▇▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
  355 ms           Histogram: frequency by time          390 ms <

 Memory estimate: 227.86 MiB, allocs estimate: 4689830.

Benchmark inference on the call graphs with most of them cached

on master

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  45.166 μs …  9.799 ms  ┊ GC (min … max): 0.00% … 98.96%
 Time  (median):     46.792 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   48.339 μs ± 97.539 μs  ┊ GC (mean ± σ):  2.01% ±  0.99%

    ▁▂▄▆▆▇███▇▆▅▄▃▄▄▂▂▂▁▁▁   ▁▁▂▂▁ ▁ ▂▁ ▁                     ▃
  ▃▇██████████████████████▇████████████████▇█▆▇▇▆▆▆▅▆▆▆▇▆▅▅▅▆ █
  45.2 μs      Histogram: log(frequency) by time        55 μs <

 Memory estimate: 25.27 KiB, allocs estimate: 614.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  303.375 μs …  16.582 ms  ┊ GC (min … max): 0.00% … 97.38%
 Time  (median):     317.625 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   338.772 μs ± 274.164 μs  ┊ GC (mean ± σ):  5.44% ±  7.56%

       ▃▆██▇▅▂▁
  ▂▂▄▅██████████▇▆▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▁▂▁▁▂▁▂▂ ▃
  303 μs           Histogram: frequency by time          394 μs <

 Memory estimate: 412.80 KiB, allocs estimate: 6224.

with this PR

@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
 Range (min … max):  5.444 μs …  1.808 ms  ┊ GC (min … max): 0.00% … 99.01%
 Time  (median):     5.694 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.228 μs ± 25.393 μs  ┊ GC (mean ± σ):  5.73% ±  1.40%

      ▄█▇▄
  ▁▂▄█████▇▄▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  5.44 μs        Histogram: frequency by time        7.47 μs <

 Memory estimate: 8.72 KiB, allocs estimate: 196.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  211.000 μs …  36.187 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     223.000 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   280.025 μs ± 750.097 μs  ┊ GC (mean ± σ):  6.86% ± 7.16%

  █▆▄▂▁                                                         ▁
  ███████▇▇▇▆▆▆▅▆▅▅▅▅▅▄▅▄▄▄▅▅▁▄▅▃▄▄▄▃▄▄▃▅▄▁▁▃▄▁▃▁▁▁▃▄▃▁▃▁▁▁▃▃▁▃ █
  211 μs        Histogram: log(frequency) by time       1.46 ms <

 Memory estimate: 374.17 KiB, allocs estimate: 5269.

@aviatesk aviatesk requested review from vtjnash and Keno June 24, 2024 16:06
@@ -138,8 +138,11 @@ use_const_api(li::CodeInstance) = invoke_api(li) == 2

function get_staged(mi::MethodInstance, world::UInt)
may_invoke_generator(mi) || return nothing
# enable caching of unoptimized generated code if the generator is `CachedGenerator`
cache_ci_ptr = (mi.def::Method).generator isa Core.CachedGenerator ?
pointer_from_objref(RefValue{CodeInstance}()) : C_NULL
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is illegal. The RefValue is lost to the GC and the pointer is now considered leaked.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be the best approach?
When this cache_ci_ptr is not null, a cache is generated and will be kept in the chain of mi->cache. Regardless, should this pointer be preserved at least during this ccall?

Base automatically changed from avi/clean-up-generated-function to master June 25, 2024 02:41
@aviatesk aviatesk force-pushed the avi/cache-unoptimized-generated branch 3 times, most recently from f57ff43 to a4479cd Compare June 26, 2024 06:59
In Cassette-like systems, where inference has to infer many calls of
`@generated` function and the generated function involves complex code
transformations, the overhead from code generation itself can become
significant. This is because the results of code generation are not
cached, leading to duplicated code generation in the following contexts:
- `method_for_inference_heuristics` for regular inference on cached
  `@generated` function calls (since `method_for_inference_limit_heuristics`
  isn't stored in cached optimized sources, but is attached to
  generated unoptimized sources).
- `retrieval_code_info` for constant propagation on cached `@generated`
  function calls.

Having said that, caching unoptimized sources generated by `@generated`
functions is not a good tradeoff in general cases, considering the
memory space consumed (and the image bloat). The code generation for
generators like `GeneratedFunctionStub` produced by the front end is
generally very simple, and the first duplicated code generation
mentioned above does not occur for `GeneratedFunctionStub`.

So this unoptimized source caching should be enabled in an opt-in manner.

Based on this idea, this commit defines the trait
`abstract type Core.CachedGenerator` as an interface for the external
systems to opt-in. If the generator is a subtype of this trait,
inference caches the generated unoptimized code, sacrificing memory
space to improve the performance of subsequent inferences.
Specifically, the mechanism for caching the unoptimized source uses the
infrastructure already implemented in #54362. Thanks to
#54362, the cache for generated functions is now
partitioned by world age, so even if the unoptimized source is cached,
the existing invalidation system will invalidate it as expected.

In JuliaDebug/CassetteOverlay.jl#56, the following benchmark results
showed that approximately 1.5~3x inference speedup is achieved by opting
into this feature:

\## Setup
```julia
using CassetteOverlay, BaseBenchmarks, BenchmarkTools

@MethodTable table;
pass = @overlaypass table;
BaseBenchmarks.load!("inference");
benchfunc1() = sin(42)
benchfunc2(xs, x) = findall(>(x), abs.(xs))
interp = BaseBenchmarks.InferenceBenchmarks.InferenceBenchmarker()

\# benchmark inference on entire call graphs from scratch
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)

\# benchmark inference on the call graphs with most of them cached
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
```

\## Benchmark inference on entire call graphs from scratch
> on master
```
julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
BenchmarkTools.Trial: 61 samples with 1 evaluation.
 Range (min … max):  78.574 ms … 87.653 ms  ┊ GC (min … max): 0.00% … 8.81%
 Time  (median):     83.149 ms              ┊ GC (median):    4.85%
 Time  (mean ± σ):   82.138 ms ±  2.366 ms  ┊ GC (mean ± σ):  3.36% ± 2.65%

  ▂ ▂▂ █     ▂                     █ ▅     ▅
  █▅██▅█▅▁█▁▁█▁▁▁▁▅▁▁▁▁▁▁▁▁▅▁▁▅██▅▅█████████▁█▁▅▁▁▁▁▁▁▁▁▁▁▁▁▅ ▁
  78.6 ms         Histogram: frequency by time        86.8 ms <

 Memory estimate: 52.32 MiB, allocs estimate: 1201192.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.345 s …  1.369 s  ┊ GC (min … max): 2.45% … 3.39%
 Time  (median):     1.355 s             ┊ GC (median):    2.98%
 Time  (mean ± σ):   1.356 s ± 9.847 ms  ┊ GC (mean ± σ):  2.96% ± 0.41%

  █                   █     █                            █
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.35 s        Histogram: frequency by time        1.37 s <

 Memory estimate: 637.96 MiB, allocs estimate: 15159639.
```
> with this PR
```
julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
BenchmarkTools.Trial: 230 samples with 1 evaluation.
 Range (min … max):  19.339 ms … 82.521 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     19.938 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.665 ms ±  4.666 ms  ┊ GC (mean ± σ):  6.72% ± 8.80%

  ▃▇█▇▄                     ▂▂▃▃▄
  █████▇█▇▆▅▅▆▅▅▁▅▁▁▁▁▁▁▁▁▁██████▆▁█▁▅▇▆▁▅▁▁▅▁▅▁▁▁▁▁▁▅▁▁▁▁▁▁▅ ▆
  19.3 ms      Histogram: log(frequency) by time      29.4 ms <

 Memory estimate: 28.67 MiB, allocs estimate: 590138.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):  354.585 ms … 390.400 ms  ┊ GC (min … max): 0.00% … 7.01%
 Time  (median):     368.778 ms               ┊ GC (median):    3.74%
 Time  (mean ± σ):   368.824 ms ±   8.853 ms  ┊ GC (mean ± σ):  3.70% ± 1.89%

             ▃            █
  ▇▁▁▁▁▁▁▁▁▁▁█▁▇▇▁▁▁▁▇▁▁▁▁█▁▁▁▁▇▁▁▇▁▁▇▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
  355 ms           Histogram: frequency by time          390 ms <

 Memory estimate: 227.86 MiB, allocs estimate: 4689830.
```

\## Benchmark inference on the call graphs with most of them cached
> on master
```
julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  45.166 μs …  9.799 ms  ┊ GC (min … max): 0.00% … 98.96%
 Time  (median):     46.792 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   48.339 μs ± 97.539 μs  ┊ GC (mean ± σ):  2.01% ±  0.99%

    ▁▂▄▆▆▇███▇▆▅▄▃▄▄▂▂▂▁▁▁   ▁▁▂▂▁ ▁ ▂▁ ▁                     ▃
  ▃▇██████████████████████▇████████████████▇█▆▇▇▆▆▆▅▆▆▆▇▆▅▅▅▆ █
  45.2 μs      Histogram: log(frequency) by time        55 μs <

 Memory estimate: 25.27 KiB, allocs estimate: 614.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  303.375 μs …  16.582 ms  ┊ GC (min … max): 0.00% … 97.38%
 Time  (median):     317.625 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   338.772 μs ± 274.164 μs  ┊ GC (mean ± σ):  5.44% ±  7.56%

       ▃▆██▇▅▂▁
  ▂▂▄▅██████████▇▆▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▁▂▁▁▂▁▂▂ ▃
  303 μs           Histogram: frequency by time          394 μs <

 Memory estimate: 412.80 KiB, allocs estimate: 6224.
```
> with this PR
```
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
 Range (min … max):  5.444 μs …  1.808 ms  ┊ GC (min … max): 0.00% … 99.01%
 Time  (median):     5.694 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.228 μs ± 25.393 μs  ┊ GC (mean ± σ):  5.73% ±  1.40%

      ▄█▇▄
  ▁▂▄█████▇▄▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  5.44 μs        Histogram: frequency by time        7.47 μs <

 Memory estimate: 8.72 KiB, allocs estimate: 196.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  211.000 μs …  36.187 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     223.000 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   280.025 μs ± 750.097 μs  ┊ GC (mean ± σ):  6.86% ± 7.16%

  █▆▄▂▁                                                         ▁
  ███████▇▇▇▆▆▆▅▆▅▅▅▅▅▄▅▄▄▄▅▅▁▄▅▃▄▄▄▃▄▄▃▅▄▁▁▃▄▁▃▁▁▁▃▄▃▁▃▁▁▁▃▃▁▃ █
  211 μs        Histogram: log(frequency) by time       1.46 ms <

 Memory estimate: 374.17 KiB, allocs estimate: 5269.
```
@aviatesk aviatesk force-pushed the avi/cache-unoptimized-generated branch from a4479cd to 426279c Compare June 27, 2024 10:35
@aviatesk aviatesk force-pushed the avi/cache-unoptimized-generated branch from 426279c to c5b985e Compare June 27, 2024 10:39
@aviatesk
Copy link
Member Author

I've added tests and confirmed that it works well in real-world packages like CassetteOverlay.jl. If there are no objections, I plan to merge it soon.

@aviatesk aviatesk merged commit b5d0b90 into master Jun 28, 2024
7 checks passed
@aviatesk aviatesk deleted the avi/cache-unoptimized-generated branch June 28, 2024 01:08
Comment on lines +153 to +154
cache_ci_ptr = pointer_from_objref(cache_ci)
src = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is dangerously unsafe for the GC

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants