
Commit 4f7f852

Allow ops.stop_gradient() to take a variable. (#20302)
Passing a tensor to `ops.stop_gradient()` always works. Passing a variable directly, however, worked with the TensorFlow backend but failed with an obscure error message on the JAX and PyTorch backends, forcing users to write `ops.stop_gradient(variable.value)`. This commit makes `ops.stop_gradient()` accept variables directly, which is a common use case.
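A minimal sketch of the use case (illustrative, not part of the commit; the shape and initializer here are arbitrary):

```python
import keras
from keras import ops

v = keras.Variable(initializer="ones", shape=(3,))

# Previously required `ops.stop_gradient(v.value)` on JAX and PyTorch;
# after this change, the variable can be passed directly on all backends.
out = ops.stop_gradient(v)
```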
1 parent 79adef4 commit 4f7f852

File tree

keras/src/backend/jax/core.py
keras/src/backend/torch/core.py
keras/src/ops/core_test.py

3 files changed: +5 lines, -1 line

keras/src/backend/jax/core.py

Lines changed: 2 additions & 0 deletions

@@ -341,6 +341,8 @@ def fori_loop(lower, upper, body_fun, init_val):


 def stop_gradient(variable):
+    if isinstance(variable, KerasVariable):
+        variable = variable.value
     return jax.lax.stop_gradient(variable)

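For reference, a small self-contained sketch of the `jax.lax.stop_gradient` semantics this backend function wraps (an illustration, not from the commit): gradients do not flow through the stopped value.

```python
import jax
import jax.numpy as jnp

def f(w):
    # The stopped factor is treated as a constant under differentiation,
    # so the gradient is stop_gradient(w) rather than 2 * w.
    return jnp.sum(w * jax.lax.stop_gradient(w))

print(jax.grad(f)(jnp.ones(3)))  # [1. 1. 1.]
```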
keras/src/backend/torch/core.py

Lines changed: 2 additions & 0 deletions

@@ -645,6 +645,8 @@ def fori_loop(lower, upper, body_fun, init_val):


 def stop_gradient(variable):
+    if isinstance(variable, KerasVariable):
+        variable = variable.value
     # We can't use `.requires_grad_(False)` here since it only
     # works when the tensor is a leaf node in the graph.
     return variable.detach()
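The comment about leaf nodes can be illustrated with a short PyTorch sketch (again illustrative, not part of the commit):

```python
import torch

x = torch.ones(3, requires_grad=True)  # leaf tensor: created by the user
y = x * 2                              # non-leaf: produced by an operation

z = y.detach()  # works on any tensor: returns a view cut from the graph
# y.requires_grad_(False)  # would raise a RuntimeError: the requires_grad
#                          # flag can only be changed on leaf tensors
```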

keras/src/ops/core_test.py

Lines changed: 1 addition & 1 deletion

@@ -565,7 +565,7 @@ def __init__(self):
                 self.b = self.add_weight(shape=(1,), initializer="zeros")

             def call(self, x, training=False):
-                return x * ops.stop_gradient(self.w.value) + self.b
+                return x * ops.stop_gradient(self.w) + self.b

         model = models.Sequential([ExampleLayer()])
         model.compile(
