How to efficiently implement large segment sums #18830
benjaminvatterj asked this question in Q&A
Replies: 1 comment 1 reply
@benjaminvatterj Hi, did you find any solution?
Hi!
I have a model that requires summing some output over a very large collection of groups, which is what jax.ops.segment_sum does. The segment indices are model constants, and this triggers a very slow and painful constant-folding step during compilation that I would like to avoid. However, I have yet to find an alternative to segment_sum that is neither very slow nor ends in a segmentation fault.
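A minimal sketch of one commonly suggested mitigation, assuming the slow compile comes from the segment ids being captured as literal constants (for example, a NumPy array closed over by the jitted function): pass the ids in as an ordinary array argument so JAX traces them instead of embedding them in the compiled program. `NUM_SEGMENTS`, `loss`, and the random ids are hypothetical placeholders, not the model from this thread; only `num_segments` has to stay a static Python int.

```python
import jax
import jax.numpy as jnp
import numpy as np

NUM_SEGMENTS = 1000  # hypothetical number of groups

# Hypothetical constant segment assignment: one group id per data element.
rng = np.random.default_rng(0)
segment_ids_np = rng.integers(0, NUM_SEGMENTS, size=100_000)

def loss(params, data, segment_ids):
    # segment_ids is a traced argument here, so it is not baked into the
    # compiled computation as a large literal constant.
    out = data * params
    sums = jax.ops.segment_sum(out, segment_ids, num_segments=NUM_SEGMENTS)
    return jnp.sum(sums ** 2)

# Differentiate with respect to params only; the ids are just passed through.
grad_fn = jax.jit(jax.grad(loss))

data = jnp.ones(segment_ids_np.shape[0])
segment_ids = jnp.asarray(segment_ids_np)
g = grad_fn(2.0, data, segment_ids)
```

Because `segment_ids` is now an input, the same compiled function is reused on every call as long as the id array keeps the same shape and dtype.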
Here is a minimal reproducible example of the problem. There are three versions: the first is the original, which works really well except when its derivative is taken; the second is too slow to use in practice; and the third results in a segmentation fault.
JAX version: 0.4.20 on CPU (Mac Studio, M2 Ultra).
Any help would be greatly appreciated!
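If the segment ids can be sorted once up front (they are constants, so this is a one-time cost), one possible alternative to `segment_sum` is a prefix-sum formulation: a cumulative sum followed by two gathers at the segment boundaries, whose gradient is built from scatter-adds at those boundaries and a reverse cumulative sum. This is a sketch under that sortedness assumption, not something tested on the original model; `sorted_segment_sum` is a hypothetical helper. Float32 prefix sums over very long arrays can also lose precision, so results should be checked against `jax.ops.segment_sum`.

```python
import jax
import jax.numpy as jnp
import numpy as np

def sorted_segment_sum(data, segment_ids, num_segments):
    """Segment sum via prefix sums; assumes segment_ids is sorted ascending."""
    # Prefix sums with a leading zero: csum[j] == sum(data[:j]).
    csum = jnp.concatenate([jnp.zeros((1,), data.dtype), jnp.cumsum(data)])
    seg = jnp.arange(num_segments)
    # First and one-past-last position of each segment in the sorted id array.
    starts = jnp.searchsorted(segment_ids, seg, side="left")
    ends = jnp.searchsorted(segment_ids, seg, side="right")
    # Empty segments have starts == ends, so their sum is zero.
    return csum[ends] - csum[starts]

# Hypothetical sorted ids; segments 50..59 are deliberately left empty.
rng = np.random.default_rng(0)
ids = jnp.asarray(np.sort(rng.integers(0, 50, size=1_000)))
x = jnp.asarray(rng.normal(size=1_000).astype(np.float32))

ref = jax.ops.segment_sum(x, ids, num_segments=60, indices_are_sorted=True)
# num_segments must be static because it sets the output shape.
alt = jax.jit(sorted_segment_sum, static_argnums=2)(x, ids, 60)
```

Whether this beats `jax.ops.segment_sum(..., indices_are_sorted=True)` on a given backend, for either the forward pass or the derivative, would need to be measured.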