Skip to content

Commit 7618e64

Browse files
Tasks: don't advance task RNG on task spawn (#49110)
1 parent 8327e85 commit 7618e64

File tree

10 files changed

+259
-43
lines changed

10 files changed

+259
-43
lines changed

base/sysimg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ let
2727
task.rngState1 = 0x7431eaead385992c
2828
task.rngState2 = 0x503e1d32781c2608
2929
task.rngState3 = 0x3a77f7189200c20b
30+
task.rngState4 = 0x5502376d099035ae
3031

3132
# Stdlibs sorted in dependency, then alphabetical, order by contrib/print_sorted_stdlibs.jl
3233
# Run with the `--exclude-jlls` option to filter out all JLL packages

src/gc.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,9 @@ static void jl_gc_run_finalizers_in_list(jl_task_t *ct, arraylist_t *list) JL_NO
501501
ct->sticky = sticky;
502502
}
503503

504-
static uint64_t finalizer_rngState[4];
504+
static uint64_t finalizer_rngState[JL_RNG_SIZE];
505505

506-
void jl_rng_split(uint64_t to[4], uint64_t from[4]) JL_NOTSAFEPOINT;
506+
void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE]) JL_NOTSAFEPOINT;
507507

508508
JL_DLLEXPORT void jl_gc_init_finalizer_rng_state(void)
509509
{
@@ -532,7 +532,7 @@ static void run_finalizers(jl_task_t *ct)
532532
jl_atomic_store_relaxed(&jl_gc_have_pending_finalizers, 0);
533533
arraylist_new(&to_finalize, 0);
534534

535-
uint64_t save_rngState[4];
535+
uint64_t save_rngState[JL_RNG_SIZE];
536536
memcpy(&save_rngState[0], &ct->rngState[0], sizeof(save_rngState));
537537
jl_rng_split(ct->rngState, finalizer_rngState);
538538

src/jltypes.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2769,7 +2769,7 @@ void jl_init_types(void) JL_GC_DISABLED
27692769
NULL,
27702770
jl_any_type,
27712771
jl_emptysvec,
2772-
jl_perm_symsvec(15,
2772+
jl_perm_symsvec(16,
27732773
"next",
27742774
"queue",
27752775
"storage",
@@ -2781,11 +2781,12 @@ void jl_init_types(void) JL_GC_DISABLED
27812781
"rngState1",
27822782
"rngState2",
27832783
"rngState3",
2784+
"rngState4",
27842785
"_state",
27852786
"sticky",
27862787
"_isexception",
27872788
"priority"),
2788-
jl_svec(15,
2789+
jl_svec(16,
27892790
jl_any_type,
27902791
jl_any_type,
27912792
jl_any_type,
@@ -2797,6 +2798,7 @@ void jl_init_types(void) JL_GC_DISABLED
27972798
jl_uint64_type,
27982799
jl_uint64_type,
27992800
jl_uint64_type,
2801+
jl_uint64_type,
28002802
jl_uint8_type,
28012803
jl_bool_type,
28022804
jl_bool_type,

src/julia.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1911,6 +1911,8 @@ typedef struct _jl_handler_t {
19111911
size_t world_age;
19121912
} jl_handler_t;
19131913

1914+
#define JL_RNG_SIZE 5 // xoshiro 4 + splitmix 1
1915+
19141916
typedef struct _jl_task_t {
19151917
JL_DATA_TYPE
19161918
jl_value_t *next; // invasive linked list for scheduler
@@ -1922,7 +1924,7 @@ typedef struct _jl_task_t {
19221924
jl_function_t *start;
19231925
// 4 byte padding on 32-bit systems
19241926
// uint32_t padding0;
1925-
uint64_t rngState[4];
1927+
uint64_t rngState[JL_RNG_SIZE];
19261928
_Atomic(uint8_t) _state;
19271929
uint8_t sticky; // record whether this Task can be migrated to a new thread
19281930
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with

src/task.c

Lines changed: 180 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -866,28 +866,187 @@ uint64_t jl_genrandom(uint64_t rngState[4]) JL_NOTSAFEPOINT
866866
return res;
867867
}
868868

869-
void jl_rng_split(uint64_t to[4], uint64_t from[4]) JL_NOTSAFEPOINT
869+
/*
870+
The jl_rng_split function forks a task's RNG state in a way that is essentially
871+
guaranteed to avoid collisions between the RNG streams of all tasks. The main
872+
RNG is the xoshiro256++ RNG whose state is stored in rngState[0..3]. There is
873+
also a small internal RNG used for task forking stored in rngState[4]. This
874+
state is used to iterate a LCG (linear congruential generator), which is then
875+
put through four different variations of the strongest PCG output function,
876+
referred to as PCG-RXS-M-XS-64 [1]. This output function is invertible: it maps
877+
a 64-bit state to 64-bit output; which is one of the reasons it's not
878+
recommended for general purpose RNGs unless space is at a premium, but in our
879+
usage invertibility is actually a benefit, as is explained below.
880+
881+
The goal of jl_rng_split is to perturb the state of each child task's RNG in
882+
such a way each that for an entire tree of tasks spawned starting with a given
883+
state in a root task, no two tasks have the same RNG state. Moreover, we want to
884+
do this in a way that is deterministic and repeatable based on (1) the root
885+
task's seed, (2) how many random numbers are generated, and (3) the task tree
886+
structure. The RNG state of a parent task is allowed to affect the initial RNG
887+
state of a child task, but the mere fact that a child was spawned should not
888+
alter the RNG output of the parent. This second requirement rules out using the
889+
main RNG to seed children -- some separate state must be maintained and changed
890+
upon forking a child task while leaving the main RNG state unchanged.
891+
892+
The basic approach is that used by the DotMix [2] and SplitMix [3] RNG systems:
893+
each task is uniquely identified by a sequence of "pedigree" numbers, indicating
894+
where in the task tree it was spawned. This vector of pedigree coordinates is
895+
then reduced to a single value by computing a dot product with a common vector
896+
of random weights. The DotMix paper provides a proof that this dot product hash
897+
value (referred to as a "compression function") is collision resistant in the
898+
sense the the pairwise collision probability of two distinct tasks is 1/N where
899+
N is the number of possible weight values. Both DotMix and SplitMix use a prime
900+
value of N because the proof requires that the difference between two distinct
901+
pedigree coordinates must be invertible, which is guaranteed by N being prime.
902+
We take a different approach: we instead limit pedigree coordinates to being
903+
binary instead -- when a task spawns a child, both tasks share the same pedigree
904+
prefix, with the parent appending a zero and the child appending a one. This way
905+
a binary pedigree vector uniquely identifies each task. Moreover, since the
906+
coordinates are binary, the difference between coordinates is always one which
907+
is its own inverse regardless of whether N is prime or not. This allows us to
908+
compute the dot product modulo 2^64 using native machine arithmetic, which is
909+
considerably more efficient and simpler to implement than arithmetic in a prime
910+
modulus. It also means that when accumulating the dot product incrementally, as
911+
described in SplitMix, we don't need to multiply weights by anything, we simply
912+
add the random weight for the current task tree depth to the parent's dot
913+
product to derive the child's dot product.
914+
915+
We use the LCG in rngState[4] to derive generate pseudorandom weights for the
916+
dot product. Each time a child is forked, we update the LCG in both parent and
917+
child tasks. In the parent, that's all we have to do -- the main RNG state
918+
remains unchanged (recall that spawning a child should *not* affect subsequence
919+
RNG draws in the parent). The next time the parent forks a child, the dot
920+
product weight used will be different, corresponding to being a level deeper in
921+
the binary task tree. In the child, we use the LCG state to generate four
922+
pseudorandom 64-bit weights (more below) and add each weight to one of the
923+
xoshiro256 state registers, rngState[0..3]. If we assume the main RNG remains
924+
unused in all tasks, then each register rngState[0..3] accumulates a different
925+
Dot/SplitMix dot product hash as additional child tasks are spawned. Each one is
926+
collision resistant with a pairwise collision chance of only 1/2^64. Assuming
927+
that the four pseudorandom 64-bit weight streams are sufficiently independent,
928+
the pairwise collision probability for distinct tasks is 1/2^256. If we somehow
929+
managed to spawn a trillion tasks, the probability of a collision would be on
930+
the order of 1/10^54. Practically impossible. Put another way, this is the same
931+
as the probability of two SHA256 hash values accidentally colliding, which we
932+
generally consider so unlikely as not to be worth worrying about.
933+
934+
What about the random "junk" that's in the xoshiro256 state registers from
935+
normal use of the RNG? For a tree of tasks spawned with no intervening samples
936+
taken from the main RNG, all tasks start with the same junk which doesn't affect
937+
the chance of collision. The Dot/SplitMix papers even suggest adding a random
938+
base value to the dot product, so we can consider whatever happens to be in the
939+
xoshiro256 registers to be that. What if the main RNG gets used between task
940+
forks? In that case, the initial state registers will be different. The DotMix
941+
collision resistance proof doesn't apply without modification, but we can
942+
generalize the setup by adding a different base constant to each compression
943+
function and observe that we still have a 1/N chance of the weight value
944+
matching that exact difference. This proves collision resistance even between
945+
tasks whose dot product hashes are computed with arbitrary offsets. We can
946+
conclude that this scheme provides collision resistance even in the face of
947+
different starting states of the main RNG. Does this seem too good to be true?
948+
Perhaps another way of thinking about it will help. Suppose we seeded each task
949+
completely randomly. Then there would also be a 1/2^256 chance of collision,
950+
just as the DotMix proof gives. Essentially what the proof is telling us is that
951+
if the weights are chosen uniformly and uncorrelated with the rest of the
952+
compression function, then the dot product construction is a good enough way to
953+
pseudorandomly seed each task. From that perspective, it's easier to believe
954+
that adding an arbitrary constant to each seed doesn't worsen its randomness.
955+
956+
This leaves us with the question of how to generate four pseudorandom weights to
957+
add to the rngState[0..3] registers at each depth of the task tree. The scheme
958+
used here is that a single 64-bit LCG state is iterated in both parent and child
959+
at each task fork, and four different variations of the PCG-RXS-M-XS-64 output
960+
function are applied to that state to generate four different pseudorandom
961+
weights. Another obvious way to generate four weights would be to iterate the
962+
LCG four times per task split. There are two main reasons we've chosen to use
963+
four output variants instead:
964+
965+
1. Advancing four times per fork reduces the set of possible weights that each
966+
register can be perturbed by from 2^64 to 2^60. Since collision resistance is
967+
proportional to the number of possible weight values, that would reduce
968+
collision resistance.
969+
970+
2. It's easier to compute four PCG output variants in parallel. Iterating the
971+
LCG is inherently sequential. Each PCG variant can be computed independently
972+
from the LCG state. All four can even be computed at once with SIMD vector
973+
instructions, but the compiler doesn't currently choose to do that.
974+
975+
A key question is whether the approach of using four variations of PCG-RXS-M-XS
976+
is sufficiently random both within and between streams to provide the collision
977+
resistance we expect. We obviously can't test that with 256 bits, but we have
978+
tested it with a reduced state analogue using four PCG-RXS-M-XS-8 output
979+
variations applied to a common 8-bit LCG. Test results do indicate sufficient
980+
independence: a single register has collisions at 2^5 while four registers only
981+
start having collisions at 2^20, which is actually better scaling of collision
982+
resistance than we expect in theory. In theory, with one byte of resistance we
983+
have a 50% chance of some collision at 20, which matches, but four bytes gives a
984+
50% chance of collision at 2^17 and our (reduced size analogue) construction is
985+
still collision free at 2^19. This may be due to the next observation, which guarantees collision avoidance for certain shapes of task trees as a result of using an
986+
invertible RNG to generate weights.
987+
988+
In the specific case where a parent task spawns a sequence of child tasks with
989+
no intervening usage of its main RNG, the parent and child tasks are actually
990+
_guaranteed_ to have different RNG states. This is true because the four PCG
991+
streams each produce every possible 2^64 bit output exactly once in the full
992+
2^64 period of the LCG generator. This is considered a weakness of PCG-RXS-M-XS
993+
when used as a general purpose RNG, but is quite beneficial in this application.
994+
Since each of up to 2^64 children will be perturbed by different weights, they
995+
cannot have hash collisions. What about parent colliding with child? That can
996+
only happen if all four main RNG registers are perturbed by exactly zero. This
997+
seems unlikely, but could it occur? Consider this part of each output function:
998+
999+
p ^= p >> ((p >> 59) + 5);
1000+
p *= m[i];
1001+
p ^= p >> 43
1002+
1003+
It's easy to check that this maps zero to zero. An unchanged parent RNG can only
1004+
happen if all four `p` values are zero at the end of this, which implies that
1005+
they were all zero at the beginning. However, that is impossible since the four
1006+
`p` values differ from `x` by different additive constants, so they cannot all
1007+
be zero. Stated more generally, this non-collision property: assuming the main
1008+
RNG isn't used between task forks, sibling and parent tasks cannot have RNG
1009+
collisions. If the task tree structure is more deeply nested or if there are
1010+
intervening uses of the main RNG, we're back to relying on "merely" 256 bits of
1011+
collision resistance, but it's nice to know that in what is likely the most
1012+
common case, RNG collisions are actually impossible. This fact may also explain
1013+
better-than-theoretical collision resistance observed in our experiment with a
1014+
reduced size analogue of our hashing system.
1015+
1016+
[1]: https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf
1017+
1018+
[2]: http://supertech.csail.mit.edu/papers/dprng.pdf
1019+
1020+
[3]: https://gee.cs.oswego.edu/dl/papers/oopsla14.pdf
1021+
*/
1022+
void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE]) JL_NOTSAFEPOINT
8701023
{
871-
/* TODO: consider a less ad-hoc construction
872-
Ideally we could just use the output of the random stream to seed the initial
873-
state of the child. Out of an overabundance of caution we multiply with
874-
effectively random coefficients, to break possible self-interactions.
875-
876-
It is not the goal to mix bits -- we work under the assumption that the
877-
source is well-seeded, and its output looks effectively random.
878-
However, xoshiro has never been studied in the mode where we seed the
879-
initial state with the output of another xoshiro instance.
880-
881-
Constants have nothing up their sleeve:
882-
0x02011ce34bce797f == hash(UInt(1))|0x01
883-
0x5a94851fb48a6e05 == hash(UInt(2))|0x01
884-
0x3688cf5d48899fa7 == hash(UInt(3))|0x01
885-
0x867b4bb4c42e5661 == hash(UInt(4))|0x01
886-
*/
887-
to[0] = 0x02011ce34bce797f * jl_genrandom(from);
888-
to[1] = 0x5a94851fb48a6e05 * jl_genrandom(from);
889-
to[2] = 0x3688cf5d48899fa7 * jl_genrandom(from);
890-
to[3] = 0x867b4bb4c42e5661 * jl_genrandom(from);
1024+
// load and advance the internal LCG state
1025+
uint64_t x = src[4];
1026+
src[4] = dst[4] = x * 0xd1342543de82ef95 + 1;
1027+
// high spectrum multiplier from https://arxiv.org/abs/2001.05304
1028+
1029+
static const uint64_t a[4] = {
1030+
0xe5f8fa077b92a8a8, // random additive offsets...
1031+
0x7a0cd918958c124d,
1032+
0x86222f7d388588d4,
1033+
0xd30cbd35f2b64f52
1034+
};
1035+
static const uint64_t m[4] = {
1036+
0xaef17502108ef2d9, // standard PCG multiplier
1037+
0xf34026eeb86766af, // random odd multipliers...
1038+
0x38fd70ad58dd9fbb,
1039+
0x6677f9b93ab0c04d
1040+
};
1041+
1042+
// PCG-RXS-M-XS output with four variants
1043+
for (int i = 0; i < 4; i++) {
1044+
uint64_t p = x + a[i];
1045+
p ^= p >> ((p >> 59) + 5);
1046+
p *= m[i];
1047+
p ^= p >> 43;
1048+
dst[i] = src[i] + p; // SplitMix dot product
1049+
}
8911050
}
8921051

8931052
JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)

stdlib/Random/src/Xoshiro.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,17 @@ struct TaskLocalRNG <: AbstractRNG end
113113
TaskLocalRNG(::Nothing) = TaskLocalRNG()
114114
rng_native_52(::TaskLocalRNG) = UInt64
115115

116-
function setstate!(x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
116+
function setstate!(
117+
x::TaskLocalRNG,
118+
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
119+
s4::UInt64 = 1s0 + 3s1 + 5s2 + 7s3, # internal splitmix state
120+
)
117121
t = current_task()
118122
t.rngState0 = s0
119123
t.rngState1 = s1
120124
t.rngState2 = s2
121125
t.rngState3 = s3
126+
t.rngState4 = s4
122127
x
123128
end
124129

@@ -128,11 +133,11 @@ end
128133
tmp = s0 + s3
129134
res = ((tmp << 23) | (tmp >> 41)) + s0
130135
t = s1 << 17
131-
s2 = xor(s2, s0)
132-
s3 = xor(s3, s1)
133-
s1 = xor(s1, s2)
134-
s0 = xor(s0, s3)
135-
s2 = xor(s2, t)
136+
s2 ⊻= s0
137+
s3 ⊻= s1
138+
s1 ⊻= s2
139+
s0 ⊻= s3
140+
s2 ⊻= t
136141
s3 = s3 << 45 | s3 >> 19
137142
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
138143
res
@@ -159,7 +164,7 @@ seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(s
159164
@inline function rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt128})
160165
first = rand(rng, UInt64)
161166
second = rand(rng,UInt64)
162-
second + UInt128(first)<<64
167+
second + UInt128(first) << 64
163168
end
164169

165170
@inline rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128
@@ -178,14 +183,14 @@ end
178183

179184
function copy!(dst::TaskLocalRNG, src::Xoshiro)
180185
t = current_task()
181-
t.rngState0, t.rngState1, t.rngState2, t.rngState3 = src.s0, src.s1, src.s2, src.s3
182-
dst
186+
setstate!(dst, src.s0, src.s1, src.s2, src.s3)
187+
return dst
183188
end
184189

185190
function copy!(dst::Xoshiro, src::TaskLocalRNG)
186191
t = current_task()
187-
dst.s0, dst.s1, dst.s2, dst.s3 = t.rngState0, t.rngState1, t.rngState2, t.rngState3
188-
dst
192+
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3)
193+
return dst
189194
end
190195

191196
function ==(a::Xoshiro, b::TaskLocalRNG)

0 commit comments

Comments
 (0)