File tree Expand file tree Collapse file tree 3 files changed +22
-2
lines changed
tensorboard/plugins/histogram Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Original file line number Diff line number Diff line change @@ -81,7 +81,12 @@ def when_nonsingular():
81
81
tf .floor (offsets / bucket_width ), dtype = tf .int32
82
82
)
83
83
clamped_indices = tf .minimum (bucket_indices , bucket_count - 1 )
84
- one_hots = tf .one_hot (clamped_indices , depth = bucket_count )
84
+ # Use float64 instead of float32 to avoid accumulating floating point error
85
+ # later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
86
+ # See https://github.com/tensorflow/tensorflow/issues/51419 for details.
87
+ one_hots = tf .one_hot (
88
+ clamped_indices , depth = bucket_count , dtype = tf .float64
89
+ )
85
90
bucket_counts = tf .cast (
86
91
tf .reduce_sum (input_tensor = one_hots , axis = 0 ),
87
92
dtype = tf .float64 ,
Original file line number Diff line number Diff line change @@ -119,6 +119,16 @@ def test_when_bucket_count_not_statically_known(self):
119
119
buckets = tensor_util .make_ndarray (pb .value [0 ].tensor )
120
120
self .assertEqual (buckets .shape , (bucket_count , 3 ))
121
121
122
+ def test_with_large_counts (self ):
123
+ # Check for accumulating floating point errors with large counts (> 2^24).
124
+ # See https://github.com/tensorflow/tensorflow/issues/51419 for details.
125
+ large_count = 20_000_000
126
+ data = [0 ] + [1 ] * large_count
127
+ pb = self .histogram ("large_count" , data = data , buckets = 2 )
128
+ buckets = tensor_util .make_ndarray (pb .value [0 ].tensor )
129
+ self .assertEqual (buckets [0 ][2 ], 1 )
130
+ self .assertEqual (buckets [1 ][2 ], large_count )
131
+
122
132
123
133
class SummaryV1PbTest (SummaryBaseTest , tf .test .TestCase ):
124
134
def histogram (self , * args , ** kwargs ):
Original file line number Diff line number Diff line change @@ -214,7 +214,12 @@ def when_nonsingular():
214
214
tf .floor (offsets / bucket_width ), dtype = tf .int32
215
215
)
216
216
clamped_indices = tf .minimum (bucket_indices , bucket_count - 1 )
217
- one_hots = tf .one_hot (clamped_indices , depth = bucket_count )
217
+ # Use float64 instead of float32 to avoid accumulating floating point error
218
+ # later in tf.reduce_sum when summing more than 2^24 individual `1.0` values.
219
+ # See https://github.com/tensorflow/tensorflow/issues/51419 for details.
220
+ one_hots = tf .one_hot (
221
+ clamped_indices , depth = bucket_count , dtype = tf .float64
222
+ )
218
223
bucket_counts = tf .cast (
219
224
tf .reduce_sum (input_tensor = one_hots , axis = 0 ),
220
225
dtype = tf .float64 ,
You can’t perform that action at this time.
0 commit comments