I have a function that processes an array of shape (n, 5), with n in the range of 1M to 10M. I am also applying sharding so that each core processes n/n_cores rows. However, when n is not divisible by n_cores, this causes issues. Currently I add some padding before the execution and remove it afterwards. When I do this, both the execution time and the RAM usage increase a lot. Is this caused by the copy performed to build a new array after the slicing? I am using `jax.jit` for both functions, so I was not expecting such a dramatic change. My questions are:
Is there a recommended practice to avoid these operations in such cases? For example, is it possible to make the sharding uneven? I looked at this function, but I did not find any mention of uneven sharding in its examples: https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html
If not, is there any other change I can make to my existing code so that performance does not vary this much depending on n?
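For concreteness, the number of padding rows needed so that every core gets the same number of rows is `(-n) % n_cores`. A minimal sketch, assuming one shard per device so that `n_cores = jax.device_count()`; the helper name `padding_needed` is hypothetical:

```python
import jax


def padding_needed(n: int, n_cores: int) -> int:
    """Hypothetical helper: smallest p >= 0 such that (n + p) % n_cores == 0."""
    # Python's % always returns a non-negative result for a positive divisor,
    # so (-n) % n_cores is 0 when n already divides evenly.
    return (-n) % n_cores


# Assumption: one shard per device.
n_cores = jax.device_count()
n = 1_000_003
n_padding = padding_needed(n, n_cores)
assert (n + n_padding) % n_cores == 0
```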
These are the functions I am using. Thank you!
```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=['n_padding'])
def add_padding(array: jnp.ndarray, n_padding: int) -> jnp.ndarray:
    """Adds zero padding to the end of an array with shape (n, 5).

    Args:
        array: Input array with shape (n, 5).
        n_padding: Number of padding rows to add (must be a non-negative integer).

    Returns:
        Array with shape (n + n_padding, 5), where the new rows are zeros.
    """
    padding_shape = (n_padding, array.shape[1])
    padding = jnp.zeros(padding_shape, dtype=array.dtype)
    return jnp.concatenate([array, padding], axis=0)


@partial(jax.jit, static_argnames=['n_padding'])
def remove_padding(array: jnp.ndarray, *, n_padding: int) -> jnp.ndarray:
    """Removes the last n_padding rows from an array with shape (n + n_padding, 5).

    Args:
        array: Input array with shape (n + n_padding, 5).
        n_padding: Number of padding rows to remove (non-negative integer).

    Returns:
        Array with shape (n, 5), without the last n_padding rows.
    """
    # Use an explicit end index: array[:-0] would return an empty array
    # when n_padding == 0.
    return array[:array.shape[0] - n_padding]
```