Skip to content

Commit 4aae578

Browse files
authored
Allow synchronization value to be set on Variables (#21072)
And use on_read synchronization for Metric variables.
1 parent b0a3080 commit 4aae578

File tree

4 files changed

+52
-1
lines changed

4 files changed

+52
-1
lines changed

keras/src/backend/common/variables.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
trainable=True,
9797
autocast=True,
9898
aggregation="none",
99+
synchronization="auto",
99100
name=None,
100101
):
101102
name = name or auto_name(self.__class__.__name__)
@@ -113,13 +114,28 @@ def __init__(
113114
"only_first_replica",
114115
):
115116
raise ValueError(
116-
"Invalid valid for argument `aggregation`. Expected "
117+
"Invalid value for argument `aggregation`. Expected "
117118
"one of `None`, `'none'`, `'mean'`, `'sum'`, "
118119
"`'only_first_replica'`. "
119120
f"Received: aggregation={aggregation}"
120121
)
121122
if aggregation is None:
122123
aggregation = "none"
124+
if synchronization not in (
125+
None,
126+
"none",
127+
"on_read",
128+
"on_write",
129+
"auto",
130+
):
131+
raise ValueError(
132+
"Invalid value for argument `synchronization`. Expected "
133+
"one of `None`, `'none'`, `'on_read'`, `'on_write'`, "
134+
"`'auto'`. "
135+
f"Received: synchronization={synchronization}"
136+
)
137+
if synchronization is None:
138+
synchronization = "none"
123139
self._name = name
124140
parent_path = current_path()
125141
if parent_path:
@@ -133,6 +149,7 @@ def __init__(
133149
self._trainable = bool(trainable)
134150
self._autocast = bool(autocast)
135151
self._aggregation = aggregation
152+
self._synchronization = synchronization
136153
# `self._overwrite_with_gradient` is an internal property to determine
137154
# whether this variable should be overwritten by the computed gradient.
138155
# Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py
@@ -227,6 +244,11 @@ def aggregation(self):
227244
"""The strategy for aggregating this variable."""
228245
return self._aggregation
229246

247+
@property
248+
def synchronization(self):
249+
"""The strategy for synchronizing this variable."""
250+
return self._synchronization
251+
230252
@property
231253
def value(self):
232254
"""The current value of the variable (numpy array or backend tensor)."""

keras/src/backend/tensorflow/core.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def _initialize(self, value):
4545
trainable=self.trainable,
4646
name=self.name,
4747
aggregation=self._map_aggregation(self.aggregation),
48+
synchronization=self._map_synchronization(self.synchronization),
4849
)
4950

5051
def _initialize_with_initializer(self, initializer):
@@ -125,6 +126,15 @@ def _map_aggregation(self, aggregation):
125126
}
126127
return mapping[aggregation]
127128

129+
def _map_synchronization(self, synchronization):
130+
mapping = {
131+
"none": tf.VariableSynchronization.NONE,
132+
"on_read": tf.VariableSynchronization.ON_READ,
133+
"on_write": tf.VariableSynchronization.ON_WRITE,
134+
"auto": tf.VariableSynchronization.AUTO,
135+
}
136+
return mapping[synchronization]
137+
128138

129139
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
130140
if isinstance(x, tf.SparseTensor) and sparse is not None and not sparse:

keras/src/backend/tensorflow/distribute_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,24 @@ def test_variable_aggregation(self):
137137
self.assertEqual(v2.aggregation, "sum")
138138
self.assertEqual(v2.value.aggregation, tf.VariableAggregation.SUM)
139139

140+
def test_variable_synchronization(self):
141+
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
142+
143+
with strategy.scope():
144+
x = np.random.random((4, 4))
145+
v1 = backend.Variable(x, dtype="float32")
146+
self.assertEqual(v1.synchronization, "auto")
147+
# AUTO with MirroredStrategy defaults to ON_WRITE
148+
self.assertEqual(
149+
v1.value.synchronization, tf.VariableSynchronization.ON_WRITE
150+
)
151+
152+
v2 = backend.Variable(x, dtype="float32", synchronization="on_read")
153+
self.assertEqual(v2.synchronization, "on_read")
154+
self.assertEqual(
155+
v2.value.synchronization, tf.VariableSynchronization.ON_READ
156+
)
157+
140158
def test_seed_generator(self):
141159
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
142160
with strategy.scope():

keras/src/metrics/metric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def add_variable(
201201
dtype=dtype,
202202
trainable=False,
203203
aggregation=aggregation,
204+
synchronization="on_read",
204205
name=name,
205206
)
206207
# Prevent double-tracking

0 commit comments

Comments
 (0)