Skip to content

Commit 5978608

Browse files
committed
Fix test
Signed-off-by: Joaquin Anton <[email protected]>
1 parent ffa5fb0 commit 5978608

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

dali/test/python/operator_1/test_debayer.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
2222
from test_utils import get_dali_extra_path
2323
from nose_utils import assert_raises
2424
from nose2.tools import params
25+
from test_utils import dali_type
2526
from debayer_test_utils import (
2627
bayer_patterns,
2728
blue_position,
@@ -42,12 +43,12 @@ def read_imgs(num_imgs, dtype, seed):
4243
@pipeline_def
4344
def pipeline():
4445
input, _ = fn.readers.file(file_root=images_dir, random_shuffle=True, seed=seed)
45-
return fn.decoders.image(input, device="cpu", output_type=types.RGB)
46+
return fn.decoders.image(input, device="cpu", output_type=types.RGB, dtype=dali_type(dtype))
4647

4748
pipe = pipeline(batch_size=num_imgs, device_id=0, num_threads=4)
4849
pipe.build()
4950
(batch,) = pipe.run()
50-
return [np.array(img, dtype=dtype) for img in batch]
51+
return [np.array(img) for img in batch]
5152

5253

5354
def read_video(num_sequences, num_frames, height, width, seed=42):
@@ -81,10 +82,6 @@ def prepare_test_imgs(num_samples, dtype):
8182
assert dtype in (np.uint8, np.uint16)
8283
rng = np.random.default_rng(seed=101)
8384
imgs = read_imgs(num_samples, dtype, seed=42 if dtype == np.uint8 else 13)
84-
if dtype == np.uint16:
85-
imgs = [
86-
np.uint16(img) * 256 + np.uint16(rng.uniform(0, 256, size=img.shape)) for img in imgs
87-
]
8885
bayered_imgs = {
8986
pattern: [rgb2bayer(img, pattern) for img in imgs] for pattern in bayer_patterns
9087
}
@@ -154,7 +151,7 @@ def debayer_pipeline():
154151
assert len(debayered_imgs) == len(idxs)
155152
for img_debayered, idx in zip(debayered_imgs, idxs):
156153
baseline = npp_baseline[pattern][idx]
157-
assert np.all(img_debayered == baseline)
154+
np.testing.assert_allclose(img_debayered, baseline)
158155

159156
@params(*itertools.product([1, 11, 184], [np.uint8, np.uint16]))
160157
def test_debayer_per_sample_pattern(self, batch_size, dtype):
@@ -200,7 +197,7 @@ def debayer_pipeline():
200197
assert len(debayered_imgs) == len(patterns) == len(idxs)
201198
for img_debayered, pattern, idx in zip(debayered_imgs, patterns, idxs):
202199
baseline = npp_baseline[pattern][idx]
203-
assert np.all(img_debayered == baseline)
200+
np.testing.assert_allclose(img_debayered, baseline)
204201

205202

206203
class DebayerVideoTest(unittest.TestCase):
@@ -260,7 +257,7 @@ def debayer_pipeline():
260257
assert len(debayered_videos) == len(idxs)
261258
for vid_debayered, idx in zip(debayered_videos, idxs):
262259
baseline = self.npp_baseline[idx]
263-
assert np.all(vid_debayered == baseline)
260+
np.testing.assert_allclose(vid_debayered, baseline)
264261

265262

266263
def source_full_array(shape, dtype):

0 commit comments

Comments
 (0)