
Special handling of Torch DDP in callback #21081


Merged (4 commits, Mar 22, 2025)
Changes from 3 commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions keras/src/callbacks/callback.py
@@ -2,6 +2,9 @@
 from keras.src import utils
 from keras.src.api_export import keras_export
 
+if backend.backend() == "torch":
+    from torch.nn.parallel import DistributedDataParallel
+
Collaborator:
Please inline the import where it is used.

Contributor Author:
Old Java habit. I moved the import inline where needed. I had to split the conditional check. Let me know how this looks.
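The inline-import pattern the reviewer asks for can be sketched in isolation. This is a hypothetical, self-contained example, not code from the PR: `json` stands in for the heavy optional dependency (torch in the actual change), and `encode` is an illustrative function invented for this sketch.

```python
# Hypothetical sketch of the inline-import pattern requested in the
# review: defer an optional dependency to the branch that needs it,
# so importing the module never requires that dependency. `json`
# stands in for the heavy optional import (torch in the actual PR).

def encode(payload, fmt="json"):
    if fmt == "json":
        # Inline import: only evaluated when this branch runs; Python
        # caches the module, so repeat calls pay almost nothing.
        import json

        return json.dumps(payload)
    raise ValueError(f"unsupported format: {fmt}")


print(encode({"a": 1}))  # {"a": 1}
```

The trade-off is a slightly slower first call on that branch in exchange for keeping the module importable on systems where the optional dependency is absent, which is why reviewers often prefer it for backend-specific imports.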



 @keras_export("keras.callbacks.Callback")
 class Callback:
@@ -76,15 +79,27 @@ def set_model(self, model):

     @property
     def model(self):
-        if backend.backend() == "jax" and hasattr(
-            self._model, "jax_state_sync"
+        if backend.backend() == "torch" and isinstance(
+            self._model, DistributedDataParallel
         ):
-            # With JAX, by default the model state is not
-            # attached to the model in the middle of an
-            # epoch. We have to force a sync before
-            # accessing model state for e.g. checkpointing.
-            self._model.jax_state_sync()
-        return self._model
+            # Keras Callbacks expect to work with Keras models, e.g.
+            # ModelCheckpoint and EarlyStopping both attempt to call
+            # keras-specific APIs on the value returned from this
+            # property. If this callback was created against a DDP
+            # wrapper instead of the underlying keras.Model, it is
+            # likely to fail. Return self._model.module for DDP
+            # instances instead.
+            return self._model.module
+        else:
+            if backend.backend() == "jax" and hasattr(
+                self._model, "jax_state_sync"
+            ):
+                # With JAX, by default the model state is not
+                # attached to the model in the middle of an
+                # epoch. We have to force a sync before
+                # accessing model state for e.g. checkpointing.
+                self._model.jax_state_sync()
+            return self._model

     @utils.default
     def on_batch_begin(self, batch, logs=None):
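The unwrapping behavior the diff adds can be illustrated without torch or Keras installed. `FakeModel` and `FakeDDP` below are hypothetical stand-ins for `keras.Model` and `torch.nn.parallel.DistributedDataParallel`; the only DDP property the sketch relies on is that the wrapper exposes the wrapped model as `.module`.

```python
# Self-contained sketch of the DDP-unwrapping behavior, assuming only
# that the distributed wrapper exposes the wrapped model via `.module`
# (as torch.nn.parallel.DistributedDataParallel does). FakeModel and
# FakeDDP are hypothetical stand-ins so this runs without torch/keras.

class FakeModel:
    """Stands in for a keras.Model with a Keras-specific API."""

    def save(self, filepath):
        return f"saved to {filepath}"


class FakeDDP:
    """Stands in for DistributedDataParallel: wraps a model as .module."""

    def __init__(self, module):
        self.module = module


class Callback:
    def set_model(self, model):
        self._model = model

    @property
    def model(self):
        # Callbacks such as ModelCheckpoint call Keras-specific APIs
        # on this property, so return the underlying model rather than
        # the distributed wrapper.
        if isinstance(self._model, FakeDDP):
            return self._model.module
        return self._model


cb = Callback()
cb.set_model(FakeDDP(FakeModel()))
print(cb.model.save("ckpt.keras"))  # saved to ckpt.keras
```

Without the unwrap, `cb.model.save(...)` would raise `AttributeError` because the wrapper has no `save` method, which is the failure mode the added comment in the diff describes for `ModelCheckpoint` and `EarlyStopping`.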