Skip to content

Commit 90c9335

Browse files
yatbeardna2github
authored and committed
histogram: cast tf.reduce_sum input to float64 (tensorflow#5337)
1 parent 58e2b7d commit 90c9335

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

tensorboard/plugins/histogram/summary.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,12 @@ def when_nonsingular():
8181
tf.floor(offsets / bucket_width), dtype=tf.int32
8282
)
8383
clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
84-
one_hots = tf.one_hot(clamped_indices, depth=bucket_count)
84+
# Use float64 instead of float32 to avoid accumulating floating point error
85+
# later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
86+
# See https://github.com/tensorflow/tensorflow/issues/51419 for details.
87+
one_hots = tf.one_hot(
88+
clamped_indices, depth=bucket_count, dtype=tf.float64
89+
)
8590
bucket_counts = tf.cast(
8691
tf.reduce_sum(input_tensor=one_hots, axis=0),
8792
dtype=tf.float64,

tensorboard/plugins/histogram/summary_test.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,16 @@ def test_when_bucket_count_not_statically_known(self):
119119
buckets = tensor_util.make_ndarray(pb.value[0].tensor)
120120
self.assertEqual(buckets.shape, (bucket_count, 3))
121121

122+
def test_with_large_counts(self):
123+
# Check for accumulating floating point errors with large counts (> 2^24).
124+
# See https://github.com/tensorflow/tensorflow/issues/51419 for details.
125+
large_count = 20_000_000
126+
data = [0] + [1] * large_count
127+
pb = self.histogram("large_count", data=data, buckets=2)
128+
buckets = tensor_util.make_ndarray(pb.value[0].tensor)
129+
self.assertEqual(buckets[0][2], 1)
130+
self.assertEqual(buckets[1][2], large_count)
131+
122132

123133
class SummaryV1PbTest(SummaryBaseTest, tf.test.TestCase):
124134
def histogram(self, *args, **kwargs):

tensorboard/plugins/histogram/summary_v2.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -214,7 +214,12 @@ def when_nonsingular():
214214
tf.floor(offsets / bucket_width), dtype=tf.int32
215215
)
216216
clamped_indices = tf.minimum(bucket_indices, bucket_count - 1)
217-
one_hots = tf.one_hot(clamped_indices, depth=bucket_count)
217+
# Use float64 instead of float32 to avoid accumulating floating point error
218+
# later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
219+
# See https://github.com/tensorflow/tensorflow/issues/51419 for details.
220+
one_hots = tf.one_hot(
221+
clamped_indices, depth=bucket_count, dtype=tf.float64
222+
)
218223
bucket_counts = tf.cast(
219224
tf.reduce_sum(input_tensor=one_hots, axis=0),
220225
dtype=tf.float64,

0 commit comments

Comments
 (0)