
Commit 3da0abd

Special handling of Torch DDP in callback (#21081)
* Special handling of Torch DDP in callback
* Use inheritance tree for DDP check: modified the DDP check to use isinstance rather than type().__name__ for robustness; fixed additional whitespace
* Fixing comment
* Inlining the DDP import where it's needed
1 parent 9fea22a commit 3da0abd

File tree

1 file changed: 13 additions, 0 deletions


keras/src/callbacks/callback.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -76,6 +76,19 @@ def set_model(self, model):
 
     @property
     def model(self):
+        if backend.backend() == "torch":
+            from torch.nn.parallel import DistributedDataParallel
+
+            if isinstance(self._model, DistributedDataParallel):
+                # Keras Callbacks expect to work with Keras models. e.g
+                # ModelCheckpoint and EarlyStopping both attempt to call
+                # keras-specific APIs on the value returned from this
+                # property. If this callback was created against a DDP
+                # wrapper instead of the underlying keras.Model, it is
+                # likely to fail. Return self._model.module for DDP
+                # instances instead.
+                return self._model.module
+
         if backend.backend() == "jax" and hasattr(
             self._model, "jax_state_sync"
         ):
```
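The unwrap pattern in this diff can be sketched without a torch installation by using stand-in classes. `FakeDDP` and `FakeModel` below are hypothetical stand-ins for `torch.nn.parallel.DistributedDataParallel` (which stores the wrapped model on its `.module` attribute) and `keras.Model`; they are not real Keras or torch classes.

```python
class FakeModel:
    """Hypothetical stand-in for a keras.Model."""

    def stop_training(self):
        # A keras-specific API a callback such as EarlyStopping might call.
        pass


class FakeDDP:
    """Hypothetical stand-in for torch.nn.parallel.DistributedDataParallel.

    Like the real DDP wrapper, it exposes the wrapped model as `.module`.
    """

    def __init__(self, module):
        self.module = module


class Callback:
    def set_model(self, model):
        self._model = model

    @property
    def model(self):
        # Mirror the commit's logic: if the stored model is a DDP-style
        # wrapper, return the underlying model so callbacks always work
        # against the keras.Model, not the wrapper.
        if isinstance(self._model, FakeDDP):
            return self._model.module
        return self._model


inner = FakeModel()
cb = Callback()
cb.set_model(FakeDDP(inner))
assert cb.model is inner  # the wrapper is transparent to the callback
```

The same property also behaves as before for unwrapped models: `set_model(inner)` followed by `cb.model` returns `inner` directly, so non-DDP code paths are unaffected.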
