Skip to content

Commit 89271dc

Browse files
authored
staticdata: Memoize type_in_worklist query (#57917)
When pre-compiling `stdlib/`, this cache has a 91% hit rate, so this seems fairly profitable.

It also dramatically improves some pathological cases, a few of which have been hit in the wild (arguably due to inference bugs).

Without this PR, this package takes exponentially long to pre-compile:

```julia
function BigType(N)
    (N == 0) && return Nothing
    T = BigType(N-1)
    return Pair{T,T}
end

foo(::Type{T}) where T = T

precompile(foo, (Type{BigType(40)},))
```

For an in-the-wild test case hit by a customer, this reduces pre-compilation time from over an hour to just ~two and a half minutes.

Resolves #53331.
1 parent fe613d4 commit 89271dc

File tree

2 files changed

+94
-52
lines changed

2 files changed

+94
-52
lines changed

src/staticdata.c

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,22 @@ static const size_t WORLD_AGE_REVALIDATION_SENTINEL = 0x1;
9090
JL_DLLEXPORT size_t jl_require_world = ~(size_t)0;
9191
JL_DLLEXPORT _Atomic(size_t) jl_first_image_replacement_world = ~(size_t)0;
9292

93+
// This structure is used to store hash tables for the memoization
94+
// of queries in staticdata.c (currently only `type_in_worklist`).
95+
typedef struct {
96+
htable_t type_in_worklist;
97+
} jl_query_cache;
98+
99+
static void init_query_cache(jl_query_cache *cache)
100+
{
101+
htable_new(&cache->type_in_worklist, 0);
102+
}
103+
104+
static void destroy_query_cache(jl_query_cache *cache)
105+
{
106+
htable_free(&cache->type_in_worklist);
107+
}
108+
93109
#include "staticdata_utils.c"
94110
#include "precompile_utils.c"
95111

@@ -552,6 +568,7 @@ typedef struct {
552568
jl_array_t *method_roots_list;
553569
htable_t method_roots_index;
554570
uint64_t worklist_key;
571+
jl_query_cache *query_cache;
555572
jl_ptls_t ptls;
556573
jl_image_t *image;
557574
int8_t incremental;
@@ -675,14 +692,13 @@ static int jl_needs_serialization(jl_serializer_state *s, jl_value_t *v) JL_NOTS
675692
return 1;
676693
}
677694

678-
679-
static int caching_tag(jl_value_t *v) JL_NOTSAFEPOINT
695+
static int caching_tag(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
680696
{
681697
if (jl_is_method_instance(v)) {
682698
jl_method_instance_t *mi = (jl_method_instance_t*)v;
683699
jl_value_t *m = mi->def.value;
684700
if (jl_is_method(m) && jl_object_in_image(m))
685-
return 1 + type_in_worklist(mi->specTypes);
701+
return 1 + type_in_worklist(mi->specTypes, query_cache);
686702
}
687703
if (jl_is_binding(v)) {
688704
jl_globalref_t *gr = ((jl_binding_t*)v)->globalref;
@@ -697,24 +713,24 @@ static int caching_tag(jl_value_t *v) JL_NOTSAFEPOINT
697713
if (jl_is_tuple_type(dt) ? !dt->isconcretetype : dt->hasfreetypevars)
698714
return 0; // aka !is_cacheable from jltypes.c
699715
if (jl_object_in_image((jl_value_t*)dt->name))
700-
return 1 + type_in_worklist(v);
716+
return 1 + type_in_worklist(v, query_cache);
701717
}
702718
jl_value_t *dtv = jl_typeof(v);
703719
if (jl_is_datatype_singleton((jl_datatype_t*)dtv)) {
704-
return 1 - type_in_worklist(dtv); // these are already recached in the datatype in the image
720+
return 1 - type_in_worklist(dtv, query_cache); // these are already recached in the datatype in the image
705721
}
706722
return 0;
707723
}
708724

709-
static int needs_recaching(jl_value_t *v) JL_NOTSAFEPOINT
725+
static int needs_recaching(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
710726
{
711-
return caching_tag(v) == 2;
727+
return caching_tag(v, query_cache) == 2;
712728
}
713729

714-
static int needs_uniquing(jl_value_t *v) JL_NOTSAFEPOINT
730+
static int needs_uniquing(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
715731
{
716732
assert(!jl_object_in_image(v));
717-
return caching_tag(v) == 1;
733+
return caching_tag(v, query_cache) == 1;
718734
}
719735

720736
static void record_field_change(jl_value_t **addr, jl_value_t *newval) JL_NOTSAFEPOINT
@@ -839,7 +855,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
839855
jl_datatype_t *dt = (jl_datatype_t*)v;
840856
// ensure all type parameters are recached
841857
jl_queue_for_serialization_(s, (jl_value_t*)dt->parameters, 1, 1);
842-
if (jl_is_datatype_singleton(dt) && needs_uniquing(dt->instance)) {
858+
if (jl_is_datatype_singleton(dt) && needs_uniquing(dt->instance, s->query_cache)) {
843859
assert(jl_needs_serialization(s, dt->instance)); // should be true, since we visited dt
844860
// do not visit dt->instance for our template object as it leads to unwanted cycles here
845861
// (it may get serialized from elsewhere though)
@@ -850,7 +866,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
850866
if (s->incremental && jl_is_method_instance(v)) {
851867
jl_method_instance_t *mi = (jl_method_instance_t*)v;
852868
jl_value_t *def = mi->def.value;
853-
if (needs_uniquing(v)) {
869+
if (needs_uniquing(v, s->query_cache)) {
854870
// we only need 3 specific fields of this (the rest are not used)
855871
jl_queue_for_serialization(s, mi->def.value);
856872
jl_queue_for_serialization(s, mi->specTypes);
@@ -865,7 +881,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
865881
record_field_change((jl_value_t**)&mi->cache, NULL);
866882
}
867883
else {
868-
assert(!needs_recaching(v));
884+
assert(!needs_recaching(v, s->query_cache));
869885
}
870886
// n.b. opaque closures cannot be inspected and relied upon like a
871887
// normal method since they can get improperly introduced by generated
@@ -875,7 +891,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
875891
// error now.
876892
}
877893
if (s->incremental && jl_is_binding(v)) {
878-
if (needs_uniquing(v)) {
894+
if (needs_uniquing(v, s->query_cache)) {
879895
jl_binding_t *b = (jl_binding_t*)v;
880896
jl_queue_for_serialization(s, b->globalref->mod);
881897
jl_queue_for_serialization(s, b->globalref->name);
@@ -1102,9 +1118,9 @@ static void jl_queue_for_serialization_(jl_serializer_state *s, jl_value_t *v, i
11021118
// Items that require postorder traversal must visit their children prior to insertion into
11031119
// the worklist/serialization_order (and also before their first use)
11041120
if (s->incremental && !immediate) {
1105-
if (jl_is_datatype(t) && needs_uniquing(v))
1121+
if (jl_is_datatype(t) && needs_uniquing(v, s->query_cache))
11061122
immediate = 1;
1107-
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v))
1123+
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v, s->query_cache))
11081124
immediate = 1;
11091125
}
11101126

@@ -1267,7 +1283,7 @@ static uintptr_t _backref_id(jl_serializer_state *s, jl_value_t *v, jl_array_t *
12671283

12681284
static void record_uniquing(jl_serializer_state *s, jl_value_t *fld, uintptr_t offset) JL_NOTSAFEPOINT
12691285
{
1270-
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld)) {
1286+
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld, s->query_cache)) {
12711287
if (jl_is_datatype(fld) || jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(fld)))
12721288
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)offset);
12731289
else if (jl_is_method_instance(fld) || jl_is_binding(fld))
@@ -1491,7 +1507,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
14911507
// write header
14921508
if (object_id_expected)
14931509
write_uint(f, jl_object_id(v));
1494-
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t))
1510+
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t, s->query_cache))
14951511
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)(ios_pos(f)|1));
14961512
if (f == s->const_data)
14971513
write_uint(s->const_data, ((uintptr_t)t->smalltag << 4) | GC_OLD_MARKED | GC_IN_IMAGE);
@@ -1502,7 +1518,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
15021518
layout_table.items[item] = (void*)(reloc_offset | (f == s->const_data)); // store the inverse mapping of `serialization_order` (`id` => object-as-streampos)
15031519

15041520
if (s->incremental) {
1505-
if (needs_uniquing(v)) {
1521+
if (needs_uniquing(v, s->query_cache)) {
15061522
if (jl_typetagis(v, jl_binding_type)) {
15071523
jl_binding_t *b = (jl_binding_t*)v;
15081524
if (b->globalref == NULL)
@@ -1531,7 +1547,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
15311547
assert(jl_is_datatype_singleton(t) && "unreachable");
15321548
}
15331549
}
1534-
else if (needs_recaching(v)) {
1550+
else if (needs_recaching(v, s->query_cache)) {
15351551
arraylist_push(jl_is_datatype(v) ? &s->fixup_types : &s->fixup_objs, (void*)reloc_offset);
15361552
}
15371553
}
@@ -1964,7 +1980,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
19641980
}
19651981
}
19661982
void *superidx = ptrhash_get(&serialization_order, dt->super);
1967-
if (s->incremental && superidx != HT_NOTFOUND && from_seroder_entry(superidx) > item && needs_uniquing((jl_value_t*)dt->super))
1983+
if (s->incremental && superidx != HT_NOTFOUND && from_seroder_entry(superidx) > item && needs_uniquing((jl_value_t*)dt->super, s->query_cache))
19681984
arraylist_push(&s->uniquing_super, dt->super);
19691985
}
19701986
else if (jl_is_typename(v)) {
@@ -2875,13 +2891,14 @@ JL_DLLEXPORT jl_value_t *jl_as_global_root(jl_value_t *val, int insert)
28752891
static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *newly_inferred,
28762892
/* outputs */ jl_array_t **extext_methods JL_REQUIRE_ROOTED_SLOT,
28772893
jl_array_t **new_ext_cis JL_REQUIRE_ROOTED_SLOT,
2878-
jl_array_t **edges JL_REQUIRE_ROOTED_SLOT)
2894+
jl_array_t **edges JL_REQUIRE_ROOTED_SLOT,
2895+
jl_query_cache *query_cache)
28792896
{
28802897
// extext_methods: [method1, ...], worklist-owned "extending external" methods added to functions owned by modules outside the worklist
28812898
// edges: [caller1, ext_targets, ...] for worklist-owned methods calling external methods
28822899

28832900
// Save the inferred code from newly inferred, external methods
2884-
*new_ext_cis = queue_external_cis(newly_inferred);
2901+
*new_ext_cis = queue_external_cis(newly_inferred, query_cache);
28852902

28862903
// Collect method extensions and edges data
28872904
*extext_methods = jl_alloc_vec_any(0);
@@ -2911,7 +2928,8 @@ static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *new
29112928
// In addition to the system image (where `worklist = NULL`), this can also save incremental images with external linkage
29122929
static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
29132930
jl_array_t *worklist, jl_array_t *extext_methods,
2914-
jl_array_t *new_ext_cis, jl_array_t *edges)
2931+
jl_array_t *new_ext_cis, jl_array_t *edges,
2932+
jl_query_cache *query_cache)
29152933
{
29162934
htable_new(&field_replace, 0);
29172935
htable_new(&bits_replace, 0);
@@ -3018,6 +3036,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
30183036
ios_mem(&gvar_record, 0);
30193037
ios_mem(&fptr_record, 0);
30203038
jl_serializer_state s = {0};
3039+
s.query_cache = query_cache;
30213040
s.incremental = !(worklist == NULL);
30223041
s.s = &sysimg;
30233042
s.const_data = &const_data;
@@ -3375,11 +3394,14 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
33753394
int64_t datastartpos = 0;
33763395
JL_GC_PUSH4(&mod_array, &extext_methods, &new_ext_cis, &edges);
33773396

3397+
jl_query_cache query_cache;
3398+
init_query_cache(&query_cache);
3399+
33783400
if (worklist) {
33793401
mod_array = jl_get_loaded_modules(); // __toplevel__ modules loaded in this session (from Base.loaded_modules_array)
33803402
// Generate `_native_data`
33813403
if (_native_data != NULL) {
3382-
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, NULL);
3404+
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, NULL, &query_cache);
33833405
jl_precompile_toplevel_module = (jl_module_t*)jl_array_ptr_ref(worklist, jl_array_len(worklist)-1);
33843406
*_native_data = jl_precompile_worklist(worklist, extext_methods, new_ext_cis);
33853407
jl_precompile_toplevel_module = NULL;
@@ -3410,7 +3432,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
34103432
assert((ct->reentrant_timing & 0b1110) == 0);
34113433
ct->reentrant_timing |= 0b1000;
34123434
if (worklist) {
3413-
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, &edges);
3435+
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, &edges, &query_cache);
34143436
if (!emit_split) {
34153437
write_int32(f, 0); // No clone_targets
34163438
write_padding(f, LLT_ALIGN(ios_pos(f), JL_CACHE_BYTE_ALIGNMENT) - ios_pos(f));
@@ -3422,7 +3444,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
34223444
}
34233445
if (_native_data != NULL)
34243446
native_functions = *_native_data;
3425-
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_ext_cis, edges);
3447+
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_ext_cis, edges, &query_cache);
34263448
if (_native_data != NULL)
34273449
native_functions = NULL;
34283450
// make sure we don't run any Julia code concurrently before this point
@@ -3451,6 +3473,8 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
34513473
}
34523474
}
34533475

3476+
destroy_query_cache(&query_cache);
3477+
34543478
JL_GC_POP();
34553479
*s = f;
34563480
if (emit_split)

src/staticdata_utils.c

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -131,63 +131,81 @@ JL_DLLEXPORT void jl_push_newly_inferred(jl_value_t* ci)
131131
JL_UNLOCK(&newly_inferred_mutex);
132132
}
133133

134-
135134
// compute whether a type references something internal to worklist
136135
// and thus could not have existed before deserialize
137136
// and thus does not need delayed unique-ing
138-
static int type_in_worklist(jl_value_t *v) JL_NOTSAFEPOINT
137+
static int type_in_worklist(jl_value_t *v, jl_query_cache *cache) JL_NOTSAFEPOINT
139138
{
140139
if (jl_object_in_image(v))
141140
return 0; // fast-path for rejection
141+
142+
void *cached = HT_NOTFOUND;
143+
if (cache != NULL)
144+
cached = ptrhash_get(&cache->type_in_worklist, v);
145+
146+
// fast-path for memoized results
147+
if (cached != HT_NOTFOUND)
148+
return cached == v;
149+
150+
int result = 0;
142151
if (jl_is_uniontype(v)) {
143152
jl_uniontype_t *u = (jl_uniontype_t*)v;
144-
return type_in_worklist(u->a) ||
145-
type_in_worklist(u->b);
153+
result = type_in_worklist(u->a, cache) ||
154+
type_in_worklist(u->b, cache);
146155
}
147156
else if (jl_is_unionall(v)) {
148157
jl_unionall_t *ua = (jl_unionall_t*)v;
149-
return type_in_worklist((jl_value_t*)ua->var) ||
150-
type_in_worklist(ua->body);
158+
result = type_in_worklist((jl_value_t*)ua->var, cache) ||
159+
type_in_worklist(ua->body, cache);
151160
}
152161
else if (jl_is_typevar(v)) {
153162
jl_tvar_t *tv = (jl_tvar_t*)v;
154-
return type_in_worklist(tv->lb) ||
155-
type_in_worklist(tv->ub);
163+
result = type_in_worklist(tv->lb, cache) ||
164+
type_in_worklist(tv->ub, cache);
156165
}
157166
else if (jl_is_vararg(v)) {
158167
jl_vararg_t *tv = (jl_vararg_t*)v;
159-
if (tv->T && type_in_worklist(tv->T))
160-
return 1;
161-
if (tv->N && type_in_worklist(tv->N))
162-
return 1;
168+
result = ((tv->T && type_in_worklist(tv->T, cache)) ||
169+
(tv->N && type_in_worklist(tv->N, cache)));
163170
}
164171
else if (jl_is_datatype(v)) {
165172
jl_datatype_t *dt = (jl_datatype_t*)v;
166-
if (!jl_object_in_image((jl_value_t*)dt->name))
167-
return 1;
168-
jl_svec_t *tt = dt->parameters;
169-
size_t i, l = jl_svec_len(tt);
170-
for (i = 0; i < l; i++)
171-
if (type_in_worklist(jl_tparam(dt, i)))
172-
return 1;
173+
if (!jl_object_in_image((jl_value_t*)dt->name)) {
174+
result = 1;
175+
}
176+
else {
177+
jl_svec_t *tt = dt->parameters;
178+
size_t i, l = jl_svec_len(tt);
179+
for (i = 0; i < l; i++) {
180+
if (type_in_worklist(jl_tparam(dt, i), cache)) {
181+
result = 1;
182+
break;
183+
}
184+
}
185+
}
173186
}
174187
else {
175-
return type_in_worklist(jl_typeof(v));
188+
return type_in_worklist(jl_typeof(v), cache);
176189
}
177-
return 0;
190+
191+
// Memoize result
192+
if (cache != NULL)
193+
ptrhash_put(&cache->type_in_worklist, (void*)v, result ? (void*)v : NULL);
194+
195+
return result;
178196
}
179197

180198
// When we infer external method instances, ensure they link back to the
181199
// package. Otherwise they might be, e.g., for external macros.
182200
// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
183-
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack)
201+
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack, jl_query_cache *query_cache)
184202
{
185203
jl_module_t *mod = mi->def.module;
186204
if (jl_is_method(mod))
187205
mod = ((jl_method_t*)mod)->module;
188206
assert(jl_is_module(mod));
189207
uint8_t is_precompiled = jl_atomic_load_relaxed(&mi->flags) & JL_MI_FLAGS_MASK_PRECOMPILED;
190-
if (is_precompiled || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes)) {
208+
if (is_precompiled || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes, query_cache)) {
191209
return 1;
192210
}
193211
if (!mi->backedges) {
@@ -211,7 +229,7 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited,
211229
jl_code_instance_t *be;
212230
i = get_next_edge(mi->backedges, i, NULL, &be);
213231
JL_GC_PROMISE_ROOTED(be); // get_next_edge propagates the edge for us here
214-
int child_found = has_backedge_to_worklist(jl_get_ci_mi(be), visited, stack);
232+
int child_found = has_backedge_to_worklist(jl_get_ci_mi(be), visited, stack, query_cache);
215233
if (child_found == 1 || child_found == 2) {
216234
// found what we were looking for, so terminate early
217235
found = 1;
@@ -243,7 +261,7 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited,
243261
// from the worklist or explicitly added by a `precompile` statement, and
244262
// (4) are the most recently computed result for that method.
245263
// These will be preserved in the image.
246-
static jl_array_t *queue_external_cis(jl_array_t *list)
264+
static jl_array_t *queue_external_cis(jl_array_t *list, jl_query_cache *query_cache)
247265
{
248266
if (list == NULL)
249267
return NULL;
@@ -262,7 +280,7 @@ static jl_array_t *queue_external_cis(jl_array_t *list)
262280
jl_method_instance_t *mi = jl_get_ci_mi(ci);
263281
jl_method_t *m = mi->def.method;
264282
if (ci->owner == jl_nothing && jl_atomic_load_relaxed(&ci->inferred) && jl_is_method(m) && jl_object_in_image((jl_value_t*)m->module)) {
265-
int found = has_backedge_to_worklist(mi, &visited, &stack);
283+
int found = has_backedge_to_worklist(mi, &visited, &stack, query_cache);
266284
assert(found == 0 || found == 1 || found == 2);
267285
assert(stack.len == 0);
268286
if (found == 1 && jl_atomic_load_relaxed(&ci->max_world) == ~(size_t)0) {

0 commit comments

Comments
 (0)