
Grad of unreduced #29219

Open · wants to merge 2 commits into main
19 changes: 19 additions & 0 deletions jax/_src/ad_util.py
@@ -139,6 +139,25 @@ def replace_rule_output_symbolic_zeros(
     x: JaxTypeOrTracer | SymbolicZero) -> JaxTypeOrTracer | Zero:
   return Zero(x.aval) if type(x) is SymbolicZero else x
 
+def ones_like_aval(aval) -> Ones | Array:
+  if aval.sharding.spec.unreduced:
+    return Ones(aval)
+  return aval_ones_likers[type(aval)](aval)
+
+
+class Ones:
+  __slots__ = ['aval']
+
+  def __init__(self, aval: core.AbstractValue):
+    self.aval = aval
+
+  def __repr__(self) -> str:
+    return f'Ones({self.aval})'
+
+register_pytree_node(Ones, lambda z: ((), z.aval), lambda aval, _: Ones(aval))
+
+aval_ones_likers: dict[type, Callable[[Any], Array]] = {}
+
 
 # TODO(mattjj): remove these after fixing downstream users relying on them
 zeros_like_p: Primitive = Primitive('zeros_like')
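The `Ones` class above is registered as a pytree node with no children: flattening it produces no array leaves unless it is explicitly treated as a leaf, which is what the `is_leaf=...` change in api.py below relies on. A minimal standalone sketch of that behavior (a local stand-in class plus a `jax.ShapeDtypeStruct` placeholder, not the PR's internal types):

# Standalone sketch only: this `Ones` is a local stand-in for the class added
# to jax/_src/ad_util.py above, registered the same way.
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten

class Ones:
  __slots__ = ['aval']
  def __init__(self, aval):
    self.aval = aval
  def __repr__(self):
    return f'Ones({self.aval})'

# Flattening yields no leaves; the aval rides along in the treedef.
register_pytree_node(Ones, lambda z: ((), z.aval), lambda aval, _: Ones(aval))

ones = Ones(jax.ShapeDtypeStruct((8, 8), jnp.float32))
print(tree_flatten(ones)[0])                                         # []
print(tree_flatten(ones, is_leaf=lambda x: isinstance(x, Ones))[0])  # [Ones(...)]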
34 changes: 17 additions & 17 deletions jax/_src/api.py
@@ -37,6 +37,7 @@
 import numpy as np
 from contextlib import contextmanager
 
+from jax._src import ad_util
 from jax._src import api_util
 from jax._src import deprecations
 from jax._src import linear_util as lu
@@ -509,11 +510,10 @@ def value_and_grad_f(*args, **kwargs):
     if not has_aux:
       ans, vjp_py = _vjp(f_partial, *dyn_args)
     else:
-      ans, vjp_py, aux = _vjp(
-          f_partial, *dyn_args, has_aux=True)
+      ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)
     _check_scalar(ans)
     tree_map(partial(_check_output_dtype_grad, holomorphic), ans)
-    g = vjp_py(lax_internal._one(ans))
+    g = vjp_py(ad_util.Ones(core.typeof(ans).to_cotangent_aval()))
     g = g[0] if isinstance(argnums, int) else g
     if not has_aux:
       return ans, g
@@ -2085,20 +2085,20 @@ def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_):
     raise TypeError(msg)
   py_args, = py_args_
   in_tree_expected, out_tree = io_tree
-  args, in_tree = tree_flatten(py_args)
-  if in_tree != in_tree_expected:
-    raise ValueError(f"unexpected tree structure of argument to vjp function: "
-                     f"got {in_tree}, but expected to match {in_tree_expected}")
-  for arg, aval in zip(args, out_primal_avals):
-    ct_aval = shaped_abstractify(arg)
-    ct_aval_expected = aval.to_tangent_aval()
-    if (not core.typecompat(ct_aval, ct_aval_expected) and
-        not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
-      raise ValueError(
-          "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: "
-          f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} "
-          f"because the corresponding output of the function {name} had JAX type "
-          f"{aval.str_short()}")
+  args, in_tree = tree_flatten(py_args, is_leaf=lambda x: isinstance(x, ad_util.Ones))
+  # if in_tree != in_tree_expected:
+  #   raise ValueError(f"unexpected tree structure of argument to vjp function: "
+  #                    f"got {in_tree}, but expected to match {in_tree_expected}")
+  # for arg, aval in zip(args, out_primal_avals):
+  #   ct_aval = shaped_abstractify(arg)
+  #   ct_aval_expected = aval.to_tangent_aval()
+  #   if (not core.typecompat(ct_aval, ct_aval_expected) and
+  #       not _temporary_dtype_exception(ct_aval, ct_aval_expected)):
+  #     raise ValueError(
+  #         "unexpected JAX type (e.g. shape/dtype) for argument to vjp function: "
+  #         f"got {ct_aval.str_short()}, but expected {ct_aval_expected.str_short()} "
+  #         f"because the corresponding output of the function {name} had JAX type "
+  #         f"{aval.str_short()}")
   ans = fun(*args)
   return tree_unflatten(out_tree, ans)
 
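The seed change in value_and_grad above swaps the concrete `lax_internal._one(ans)` for a symbolic `ad_util.Ones(...)` carrying the output's cotangent aval; semantically the pullback is still fed an all-ones cotangent of the scalar output. That invariant can be checked with public APIs alone, independent of this PR:

import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x * x)
x = jnp.arange(3.)

ans, pullback = jax.vjp(f, x)
g, = pullback(jnp.ones_like(ans))       # explicit all-ones (scalar) cotangent
assert jnp.allclose(g, jax.grad(f)(x))  # matches grad's implicit seed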
14 changes: 12 additions & 2 deletions jax/_src/core.py
@@ -2057,9 +2057,15 @@ def __hash__(self):
                  self.vma))
 
   def to_tangent_aval(self):
+    dtype = primal_dtype_to_tangent_dtype(self.dtype)
     return ShapedArray(
-        self.shape, primal_dtype_to_tangent_dtype(self.dtype),
-        self.weak_type, sharding=self.sharding, vma=self.vma)
+        self.shape, dtype, self.weak_type, sharding=self.sharding, vma=self.vma)
+
+  def to_cotangent_aval(self):
+    dtype = primal_dtype_to_tangent_dtype(self.dtype)
+    sharding = primal_sharding_to_cotangent_sharding(self.sharding)
+    return ShapedArray(
+        self.shape, dtype, self.weak_type, sharding=sharding, vma=self.vma)
 
   def str_short(self, short_dtypes=False, mesh_axis_types=False):
     return str_short_aval(

@@ -2106,6 +2112,10 @@ def primal_dtype_to_tangent_dtype(primal_dtype):
   else:
     return primal_dtype
 
+def primal_sharding_to_cotangent_sharding(sharding):
+  new_spec = P(*sharding.spec, unreduced=sharding.replicated_axes)
+  return sharding.with_spec(new_spec)
+
 
 def pvary(x, axis_name):
   axes = (axis_name,) if not isinstance(axis_name, tuple) else axis_name
12 changes: 9 additions & 3 deletions jax/_src/lax/lax.py
@@ -3358,11 +3358,15 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
   else:
     scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
   out = broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding)
-  out = core.pvary(out, tuple(aval.vma))
-  return out
-
+  return core.pvary(out, tuple(aval.vma))
 ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
 
+def ones_like_shaped_array(aval: ShapedArray) -> Array:
+  one = _convert_element_type(1, aval.dtype, aval.weak_type)
+  out = broadcast(one, aval.shape, out_sharding=aval.sharding)
+  return core.pvary(out, tuple(aval.vma))
+ad_util.aval_ones_likers[ShapedArray] = ones_like_shaped_array
+
 def zeros_like_abstract_ref(aval: state.AbstractRef) -> core.MutableArray:
   val = ad_util.zeros_like_aval(aval.inner_aval)
   return core.mutable_array(val)

@@ -7638,6 +7642,8 @@ def _reduce_number_dtype_rule(name, operand, *args, **kw):
 
 def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
   assert ad.is_undefined_primal(operand)
+  if isinstance(cotangent, ad_util.Ones):
+    return [ad_util.ones_like_aval(operand.aval.to_cotangent_aval())]
   input_shape = operand.aval.shape
   broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
   result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions,
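The new early return in `_reduce_sum_transpose_rule` leans on the usual identity: transposing reduce_sum broadcasts the cotangent back over the reduced axes, so an all-ones cotangent transposes to all-ones of the operand's (cotangent) aval. The unreduced-sharding handling needs this PR, but the identity itself is easy to confirm with public APIs:

import jax
import jax.numpy as jnp

x = jnp.arange(6.).reshape(2, 3)
# d/dx sum(x) is an array of ones with x's shape: the scalar ones cotangent
# is broadcast back over the summed axes by the transpose rule.
print(jax.grad(lambda x: x.sum())(x))  # [[1. 1. 1.]
                                       #  [1. 1. 1.]]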
8 changes: 8 additions & 0 deletions jax/_src/named_sharding.py
@@ -137,6 +137,14 @@ def __reduce__(self):
             {'memory_kind': self.memory_kind,
              '_logical_device_ids': self._logical_device_ids})
 
+  @property
+  def replicated_axes(self):
+    other_axes = {ax for entry in self.spec
+                  for ax in (entry if isinstance(entry, tuple) else (entry,))
+                  if ax is not None}
+    other_axes |= set(self.spec.unreduced)
+    return tuple(n for n in self.mesh.axis_names if n not in other_axes)
+
   @property
   def memory_kind(self) -> str | None:
     return self._memory_kind
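`replicated_axes` collects the mesh axes that appear neither in the spec nor in its existing `unreduced` set; `primal_sharding_to_cotangent_sharding` in core.py then marks exactly those axes as unreduced on the cotangent. A hedged sketch of that computation with public `jax.sharding` APIs; it assumes a JAX build whose `PartitionSpec` accepts `unreduced=` (as the code in this PR does) and at least four devices (e.g. CPU with `--xla_force_host_platform_device_count=4`):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
primal = NamedSharding(mesh, P('x', None))   # replicated over mesh axis 'y'

# Same computation as replicated_axes, written out inline (this primal spec
# has no unreduced axes, so that part is omitted).
used = {ax for entry in primal.spec if entry is not None
        for ax in (entry if isinstance(entry, tuple) else (entry,))}
replicated = tuple(n for n in mesh.axis_names if n not in used)
print(replicated)  # ('y',)

# The cotangent sharding then carries those axes as unreduced, roughly
# PartitionSpec('x', None, unreduced=...):
cotangent_spec = P(*primal.spec, unreduced=replicated)
print(cotangent_spec)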
6 changes: 5 additions & 1 deletion jax/_src/pjit.py
@@ -28,6 +28,7 @@
 
 import numpy as np
 
+from jax._src import ad_util
 from jax._src import api
 from jax._src import api_util
 from jax._src import config

@@ -3007,7 +3008,10 @@ def _reshard_impl(x, dst_sharding):
 reshard_p.def_impl(_reshard_impl)
 
 def _reshard_transpose_rule(ct, x, dst_sharding):
-  return [reshard_p.bind(ct, dst_sharding=x.aval.sharding)]
+  if isinstance(ct, ad_util.Ones):
+    return [ad_util.ones_like_aval(x.aval.to_cotangent_aval())]
+  else:
+    return [reshard_p.bind(ct, dst_sharding=x.aval.sharding)]
 ad.deflinear2(reshard_p, _reshard_transpose_rule)
 
 def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
69 changes: 68 additions & 1 deletion tests/pjit_test.py
@@ -7764,7 +7764,7 @@ def f(x):
@config.use_shardy_partitioner(True)
@jtu.with_explicit_mesh((2, 2), ('x', 'y'))
def test_unreduced_basic(self, mesh):
-np_inp = np.arange(16).reshape(8, 2)
+np_inp = np.arange(16.).reshape(8, 2)
x = jax.device_put(np_inp, P('x', 'y'))
y = jax.device_put(np_inp.T, P('y', None))
a = jax.device_put(np_inp, P('x', 'y'))
@@ -7790,6 +7790,73 @@ def f(x, y, a, b):
self.assertIn('unreduced={"y"}', lowered_text)
self.assertTrue(lowered_text.count('unreduced={"y"}') == 3)

# { lambda ; a:f64[8@x,2@y] b:f64[2@y,8] c:f64[8@x,2@y] d:f64[2@y,8]. let
# e:f64[8@x,8] = pjit[
# name=f
# ctx_mesh=Mesh('x': 2, 'y': 2, axis_types=(Explicit, Explicit))
# jaxpr={ lambda ; a:f64[8@x,2@y] b:f64[2@y,8] c:f64[8@x,2@y] d:f64[2@y,8]. let
# f:f64[8@x,8]{U:y} = dot_general[
# dimension_numbers=(([1], [0]), ([], []))
# out_sharding=P('x', unreduced=('y',)))
# preferred_element_type=float64
# ] a b
# g:f64[8@x,8]{U:y} = dot_general[
# dimension_numbers=(([1], [0]), ([], []))
# out_sharding=P('x', unreduced=('y',)))
# preferred_element_type=float64
# ] c d
# h:f64[8@x,8]{U:y} = add f g
# e:f64[8@x,8] = reshard[
# dst_sharding=P('x', None))
# ] h
# in (e,) }
# ] a b c d
# _:f64[] = reduce_sum[axes=(0, 1)] e
# i:f64[] = broadcast_in_dim[
# broadcast_dimensions=()
# shape=()
# sharding=P())
# ] 1.0:f64[]
# j:f64[8@x,8] = broadcast_in_dim[
# broadcast_dimensions=()
# shape=(8, 8)
# sharding=P('x', None))
# ] i
# k:f64[8@x,2@y] l:f64[2@y,8] m:f64[8@x,2@y] n:f64[2@y,8] = pjit[
# name=f
# ctx_mesh=Mesh('x': 2, 'y': 2, axis_types=(Explicit, Explicit))
# jaxpr={ lambda ; b:f64[2@y,8] a:f64[8@x,2@y] d:f64[2@y,8] c:f64[8@x,2@y] j:f64[8@x,8]. let
# o:f64[8@x,8]{U:y} = reshard[
# dst_sharding=P('x', None, unreduced=('y',)))
# ] j
# p:f64[8,2@y] = dot_general[
# dimension_numbers=(([0], [0]), ([], []))
# out_sharding=P(None, 'y'))
# preferred_element_type=float64
# ] o c
# n:f64[2@y,8] = transpose[permutation=(1, 0)] p
# m:f64[8@x,2@y] = dot_general[
# dimension_numbers=(([1], [1]), ([], []))
# out_sharding=P('x', 'y'))
# preferred_element_type=float64
# ] o d
# q:f64[8,2@y] = dot_general[
# dimension_numbers=(([0], [0]), ([], []))
# out_sharding=P(None, 'y'))
# preferred_element_type=float64
# ] o a
# l:f64[2@y,8] = transpose[permutation=(1, 0)] q
# k:f64[8@x,2@y] = dot_general[
# dimension_numbers=(([1], [1]), ([], []))
# out_sharding=P('x', 'y'))
# preferred_element_type=float64
# ] o b
# in (k, l, m, n) }
# ] b a d c j
# in (k, l, m, n) }
print(jax.jit(jax.grad(lambda x, y, a, b: f(x, y, a, b).sum(),
argnums=(0, 1, 2, 3))).trace(x, y, a, b).jaxpr)

@jtu.with_explicit_mesh((2, 2, 1), ('x', 'y', 'z'))
def test_dot_general_unreduced_error(self, mesh):
np_inp = np.arange(16).reshape(8, 2)