1
- # Copyright (c) 2022-2023 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # Copyright (c) 2022-2024 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
22
22
from test_utils import get_dali_extra_path
23
23
from nose_utils import assert_raises
24
24
from nose2 .tools import params
25
+ from test_utils import dali_type
25
26
from debayer_test_utils import (
26
27
bayer_patterns ,
27
28
blue_position ,
@@ -42,12 +43,12 @@ def read_imgs(num_imgs, dtype, seed):
42
43
@pipeline_def
43
44
def pipeline ():
44
45
input , _ = fn .readers .file (file_root = images_dir , random_shuffle = True , seed = seed )
45
- return fn .decoders .image (input , device = "cpu" , output_type = types .RGB )
46
+ return fn .decoders .image (input , device = "cpu" , output_type = types .RGB , dtype = dali_type ( dtype ) )
46
47
47
48
pipe = pipeline (batch_size = num_imgs , device_id = 0 , num_threads = 4 )
48
49
pipe .build ()
49
50
(batch ,) = pipe .run ()
50
- return [np .array (img , dtype = dtype ) for img in batch ]
51
+ return [np .array (img ) for img in batch ]
51
52
52
53
53
54
def read_video (num_sequences , num_frames , height , width , seed = 42 ):
@@ -81,10 +82,6 @@ def prepare_test_imgs(num_samples, dtype):
81
82
assert dtype in (np .uint8 , np .uint16 )
82
83
rng = np .random .default_rng (seed = 101 )
83
84
imgs = read_imgs (num_samples , dtype , seed = 42 if dtype == np .uint8 else 13 )
84
- if dtype == np .uint16 :
85
- imgs = [
86
- np .uint16 (img ) * 256 + np .uint16 (rng .uniform (0 , 256 , size = img .shape )) for img in imgs
87
- ]
88
85
bayered_imgs = {
89
86
pattern : [rgb2bayer (img , pattern ) for img in imgs ] for pattern in bayer_patterns
90
87
}
@@ -154,7 +151,7 @@ def debayer_pipeline():
154
151
assert len (debayered_imgs ) == len (idxs )
155
152
for img_debayered , idx in zip (debayered_imgs , idxs ):
156
153
baseline = npp_baseline [pattern ][idx ]
157
- assert np .all (img_debayered == baseline )
154
+ np .testing . assert_allclose (img_debayered , baseline )
158
155
159
156
@params (* itertools .product ([1 , 11 , 184 ], [np .uint8 , np .uint16 ]))
160
157
def test_debayer_per_sample_pattern (self , batch_size , dtype ):
@@ -200,7 +197,7 @@ def debayer_pipeline():
200
197
assert len (debayered_imgs ) == len (patterns ) == len (idxs )
201
198
for img_debayered , pattern , idx in zip (debayered_imgs , patterns , idxs ):
202
199
baseline = npp_baseline [pattern ][idx ]
203
- assert np .all (img_debayered == baseline )
200
+ np .testing . assert_allclose (img_debayered , baseline )
204
201
205
202
206
203
class DebayerVideoTest (unittest .TestCase ):
@@ -260,7 +257,7 @@ def debayer_pipeline():
260
257
assert len (debayered_videos ) == len (idxs )
261
258
for vid_debayered , idx in zip (debayered_videos , idxs ):
262
259
baseline = self .npp_baseline [idx ]
263
- assert np .all (vid_debayered == baseline )
260
+ np .testing . assert_allclose (vid_debayered , baseline )
264
261
265
262
266
263
def source_full_array (shape , dtype ):
0 commit comments