Skip to content

Commit e812c17

Browse files
mattdangerwtensorflower-gardener
authored andcommitted
Fix RandomCrop layer called on integer input
We had a tf.cond which called tf.image.resize on one path (returning float32) and cropping the input dtype on the other, which lead to a error. Instead we cast the inputs to the layer's compute dtype and do our computation with that type. Fixes #15681 PiperOrigin-RevId: 417686461
1 parent 2049ff5 commit e812c17

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

keras/layers/preprocessing/image_preprocessing.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def __init__(self, height, width, seed=None, **kwargs):
253253
def call(self, inputs, training=True):
254254
if training is None:
255255
training = backend.learning_phase()
256-
inputs = utils.ensure_tensor(inputs)
256+
inputs = utils.ensure_tensor(inputs, dtype=self.compute_dtype)
257257
input_shape = tf.shape(inputs)
258258
h_diff = input_shape[H_AXIS] - self.height
259259
w_diff = input_shape[W_AXIS] - self.width
@@ -266,10 +266,14 @@ def random_crop():
266266
return tf.image.crop_to_bounding_box(inputs, h_start, w_start,
267267
self.height, self.width)
268268

269-
outputs = tf.cond(
269+
def resize():
270+
outputs = smart_resize(inputs, [self.height, self.width])
271+
# smart_resize will always output float32, so we need to re-cast.
272+
return tf.cast(outputs, self.compute_dtype)
273+
274+
return tf.cond(
270275
tf.reduce_all((training, h_diff >= 0, w_diff >= 0)), random_crop,
271-
lambda: smart_resize(inputs, [self.height, self.width]))
272-
return tf.cast(outputs, self.compute_dtype)
276+
resize)
273277

274278
def compute_output_shape(self, input_shape):
275279
input_shape = tf.TensorShape(input_shape).as_list()

keras/layers/preprocessing/image_preprocessing_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import functools
1818
from absl.testing import parameterized
1919

20+
import keras
2021
from keras import keras_parameterized
2122
from keras import testing_utils
2223
from keras.engine import sequential
@@ -398,6 +399,12 @@ def test_unbatched_image(self):
398399
actual_output = layer(inp, training=True)
399400
self.assertAllClose(inp[2:10, 2:10, :], actual_output)
400401

402+
@testing_utils.run_v2_only
403+
def test_uint8_input(self):
404+
inputs = keras.Input((128, 128, 3), batch_size=2, dtype=tf.uint8)
405+
layer = image_preprocessing.RandomCrop(64, 64)
406+
self.assertAllEqual(layer(inputs).dtype, 'float32')
407+
401408
@testing_utils.run_v2_only
402409
def test_output_dtypes(self):
403410
inputs = np.array([[[1], [2]], [[3], [4]]], dtype='float64')

0 commit comments

Comments
 (0)