Skip to content

Commit ec9c235

Browse files
committed
Tests and numerous fixes to AttachSamples.
Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent d358937 commit ec9c235

File tree

3 files changed

+129
-44
lines changed

3 files changed

+129
-44
lines changed

dali/c_api_2/data_objects.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ class TensorListWrapper : public ITensorList {
452452
"The number of dimensions must not be negative when num_samples is 0.");
453453
else
454454
ndim = samples[0].ndim;
455+
ValidateNDim(ndim);
455456
}
456457
if (dtype == DALI_NO_TYPE) {
457458
if (num_samples == 0)
@@ -461,31 +462,39 @@ class TensorListWrapper : public ITensorList {
461462
}
462463
Validate(dtype);
463464

465+
TensorLayout new_layout = {};
466+
467+
if (!layout) {
468+
if (num_samples > 0) {
469+
new_layout = samples[0].layout;
470+
Validate(new_layout, ndim);
471+
} else if (ndim == tl_->sample_dim()) {
472+
new_layout = tl_->GetLayout();
473+
}
474+
} else {
475+
new_layout = layout;
476+
Validate(new_layout, ndim);
477+
}
478+
464479
for (int i = 0; i < num_samples; i++) {
465480
if (ndim && !samples[i].shape)
466481
throw std::invalid_argument(make_string("Got NULL shape in sample ", i, "."));
467482
if (samples[i].dtype != dtype)
468483
throw std::invalid_argument(make_string("Unexpected data type in sample ", i, ". Got: ",
469484
samples[i].dtype, ", expected ", dtype, "."));
470485
ValidateSampleShape(i, make_cspan(samples[i].shape, samples[i].ndim), ndim);;
486+
if (samples[i].layout && new_layout != samples[i].layout)
487+
throw std::invalid_argument(make_string("Unexpected layout \"", samples[i].layout,
488+
"\" in sample ", i, ". Expected: \"", new_layout, "\"."));
471489

472490
if (!samples[i].data && volume(make_cspan(samples[i].shape, ndim)))
473491
throw std::invalid_argument(make_string(
474492
"Got NULL data pointer in a non-empty sample ", i, "."));
475493
}
476494

477-
TensorLayout new_layout = {};
478-
479-
if (!layout) {
480-
if (ndim == tl_->sample_dim())
481-
new_layout = tl_->GetLayout();
482-
} else {
483-
new_layout = layout;
484-
Validate(new_layout, ndim);
485-
}
486-
487495
tl_->Reset();
488496
tl_->SetSize(num_samples);
497+
tl_->set_type(dtype);
489498
tl_->set_sample_dim(ndim);
490499
tl_->SetLayout(new_layout);
491500

dali/c_api_2/data_objects_test.cc

Lines changed: 106 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ TEST(CAPI2_TensorListTest, CreateDestroy) {
3434
ASSERT_NE(h, nullptr);
3535
dali::c_api::TensorListHandle tl(h);
3636
ASSERT_EQ(h, tl.get());
37-
ASSERT_EQ(r, DALI_SUCCESS);
37+
ASSERT_EQ(r, DALI_SUCCESS) << daliGetLastErrorMessage();
3838

3939
int ref = -1;
40-
EXPECT_EQ(daliTensorListRefCount(h, &ref), DALI_SUCCESS);
40+
EXPECT_EQ(daliTensorListRefCount(h, &ref), DALI_SUCCESS) << daliGetLastErrorMessage();
4141
EXPECT_EQ(ref, 1);
4242
ref = -1;
4343

4444
h = tl.release();
45-
EXPECT_EQ(daliTensorListDecRef(h, &ref), DALI_SUCCESS);
45+
EXPECT_EQ(daliTensorListDecRef(h, &ref), DALI_SUCCESS) << daliGetLastErrorMessage();
4646
EXPECT_EQ(ref, 0);
4747
}
4848

@@ -81,17 +81,21 @@ void TestTensorListResize(daliStorageDevice_t storage_device) {
8181
shapes[0] = -1;
8282
EXPECT_EQ(daliTensorListResize(tl, 4, 3, shapes, dtype, "HWC"), DALI_ERROR_INVALID_ARGUMENT);
8383
shapes[0] = 480;
84-
EXPECT_EQ(daliTensorListResize(tl, 1, 3, shapes, dtype, "HWC"), DALI_SUCCESS);
84+
EXPECT_EQ(daliTensorListResize(tl, 1, 3, shapes, dtype, "HWC"), DALI_SUCCESS)
85+
<< daliGetLastErrorMessage();
8586
// resize, but keep the layout
86-
EXPECT_EQ(daliTensorListResize(tl, 4, 3, shapes, dtype, nullptr), DALI_SUCCESS);
87+
EXPECT_EQ(daliTensorListResize(tl, 4, 3, shapes, dtype, nullptr), DALI_SUCCESS)
88+
<< daliGetLastErrorMessage();
8789

8890
size_t element_size = dali::TypeTable::GetTypeInfo(dtype).size();
8991

90-
EXPECT_EQ(daliTensorListGetShape(tl, nullptr, nullptr, nullptr), DALI_SUCCESS);
92+
EXPECT_EQ(daliTensorListGetShape(tl, nullptr, nullptr, nullptr), DALI_SUCCESS)
93+
<< daliGetLastErrorMessage();
9194
{
9295
int nsamples = -1, ndim = -1;
9396
const int64_t *shape_data = nullptr;
94-
EXPECT_EQ(daliTensorListGetShape(tl, &nsamples, &ndim, &shape_data), DALI_SUCCESS);
97+
EXPECT_EQ(daliTensorListGetShape(tl, &nsamples, &ndim, &shape_data), DALI_SUCCESS)
98+
<< daliGetLastErrorMessage();
9599
ASSERT_NE(shape_data, nullptr);
96100
EXPECT_EQ(nsamples, 4);
97101
EXPECT_EQ(ndim, 3);
@@ -105,13 +109,13 @@ void TestTensorListResize(daliStorageDevice_t storage_device) {
105109
const char *base;
106110
for (int i = 0; i < 4; i++) {
107111
daliTensorDesc_t desc{};
108-
EXPECT_EQ(daliTensorListGetTensorDesc(tl, &desc, i), DALI_SUCCESS);
112+
EXPECT_EQ(daliTensorListGetTensorDesc(tl, &desc, i), DALI_SUCCESS) << daliGetLastErrorMessage();
109113
ASSERT_EQ(desc.ndim, 3);
110-
ASSERT_NE(desc.data, nullptr);
111114
if (i == 0)
112115
base = static_cast<char *>(desc.data);
113116
EXPECT_EQ(desc.data, base + offset);
114117
EXPECT_EQ(desc.dtype, dtype);
118+
ASSERT_NE(desc.shape, nullptr);
115119
for (int j = 0; j < 3; j++)
116120
EXPECT_EQ(desc.shape[j], shapes[3 * i + j]);
117121
size_t sample_bytes = volume(dali::make_cspan(desc.shape, desc.ndim)) * element_size;
@@ -148,7 +152,7 @@ MakeTestDeleter(element_t *expected_data) {
148152
EXPECT_EQ(ctx->buffer_delete_count, 0);
149153
EXPECT_EQ(data, ctx->expected_data);
150154
ctx->buffer_delete_count++;
151-
delete [] static_cast<element_t *>(data);
155+
// do not actually delete the data
152156
};
153157
deleter.destroy_context = [](void *vctx) {
154158
auto *ctx = static_cast<TestDeleterCtx *>(vctx);
@@ -189,21 +193,22 @@ TEST(CAPI2_TensorListTest, AttachBuffer) {
189193
"HWC",
190194
data.get(),
191195
offsets,
192-
deleter), DALI_SUCCESS);
196+
deleter), DALI_SUCCESS) << daliGetLastErrorMessage();;
193197

194-
void *data_ptr = data.release(); // the buffer is now owned by the tensor list
198+
// The deleter doesn't actually delete - we still own the data.
195199

196200
ptrdiff_t offset = 0;
197-
const char *base = static_cast<const char *>(data_ptr);
201+
const char *base = reinterpret_cast<const char *>(data.get());
198202
for (int i = 0; i < 4; i++) {
199203
daliTensorDesc_t desc{};
200-
EXPECT_EQ(daliTensorListGetTensorDesc(tl, &desc, i), DALI_SUCCESS);
204+
EXPECT_EQ(daliTensorListGetTensorDesc(tl, &desc, i), DALI_SUCCESS) << daliGetLastErrorMessage();
201205
ASSERT_EQ(desc.ndim, 3);
202-
ASSERT_NE(desc.data, nullptr);
203206
EXPECT_EQ(desc.data, base + offset);
204207
EXPECT_EQ(desc.dtype, dtype);
208+
ASSERT_NE(desc.shape, nullptr);
205209
for (int j = 0; j < 3; j++)
206210
EXPECT_EQ(desc.shape[j], lshape[i][j]);
211+
EXPECT_STREQ(desc.layout, "HWC");
207212
size_t sample_bytes = volume(dali::make_cspan(desc.shape, desc.ndim)) * sizeof(element_t);
208213
offset += sample_bytes;
209214
}
@@ -215,6 +220,74 @@ TEST(CAPI2_TensorListTest, AttachBuffer) {
215220
}
216221

217222

223+
TEST(CAPI2_TensorListTest, AttachSamples) {
224+
daliBufferPlacement_t placement{};
225+
placement.device_type = DALI_STORAGE_CPU;
226+
using element_t = int;
227+
daliDataType_t dtype = dali::type2id<element_t>::value;
228+
dali::TensorListShape<> lshape({
229+
{ 480, 640, 3 },
230+
{ 600, 800, 4 },
231+
{ 348, 720, 1 },
232+
{ 1080, 1920, 3 }
233+
});
234+
auto size = lshape.num_elements();
235+
int N = lshape.num_samples();
236+
std::vector<std::unique_ptr<element_t>> data(N);
237+
238+
for (int i = 0; i < N; i++) {
239+
data[i].reset(new element_t[size]);
240+
}
241+
242+
std::vector<daliDeleter_t> deleters(N);
243+
std::vector<std::unique_ptr<TestDeleterCtx>> deleter_ctxs(N);
244+
245+
for (int i = 0; i < N; i++) {
246+
std::tie(deleters[i], deleter_ctxs[i]) = MakeTestDeleter(data[i].get());
247+
}
248+
249+
std::vector<daliTensorDesc_t> samples(N);
250+
251+
for (int i = 0; i < N; i++) {
252+
samples[i].ndim = lshape.sample_dim();
253+
samples[i].dtype = dtype;
254+
samples[i].layout = i == 0 ? "HWC" : nullptr;
255+
samples[i].shape = lshape.tensor_shape_span(i).data();
256+
samples[i].data = data[i].get();
257+
}
258+
259+
auto tl = CreateTensorList(placement);
260+
ASSERT_EQ(daliTensorListAttachSamples(
261+
tl,
262+
lshape.num_samples(),
263+
-1,
264+
DALI_NO_TYPE,
265+
nullptr,
266+
samples.data(),
267+
deleters.data()), DALI_SUCCESS) << daliGetLastErrorMessage();
268+
269+
// The deleter doesn't actually delete - we still own the data.
270+
for (int i = 0; i < 4; i++) {
271+
daliTensorDesc_t desc{};
272+
EXPECT_EQ(daliTensorListGetTensorDesc(tl, &desc, i), DALI_SUCCESS) << daliGetLastErrorMessage();
273+
ASSERT_EQ(desc.ndim, 3);
274+
EXPECT_EQ(desc.data, data[i].get());
275+
EXPECT_EQ(desc.dtype, dtype);
276+
ASSERT_NE(desc.shape, nullptr);
277+
for (int j = 0; j < 3; j++)
278+
EXPECT_EQ(desc.shape[j], lshape[i][j]);
279+
EXPECT_STREQ(desc.layout, "HWC");
280+
}
281+
282+
tl.reset();
283+
284+
for (auto &ctx : deleter_ctxs) {
285+
EXPECT_EQ(ctx->buffer_delete_count, 1) << "Buffer deleter not called";
286+
EXPECT_EQ(ctx->context_delete_count, 1) << "Deleter context not destroyed";
287+
}
288+
}
289+
290+
218291
TEST(CAPI2_TensorListTest, ViewAsTensor) {
219292
daliBufferPlacement_t placement{};
220293
placement.device_type = DALI_STORAGE_CPU;
@@ -245,18 +318,18 @@ TEST(CAPI2_TensorListTest, ViewAsTensor) {
245318
"HWC",
246319
data.get(),
247320
offsets,
248-
deleter), DALI_SUCCESS);
321+
deleter), DALI_SUCCESS) << daliGetLastErrorMessage();
249322

250-
void *data_ptr = data.release(); // the buffer is now owned by the tensor list
323+
// The deleter doesn't actually delete - we still own the data.
251324

252325
daliTensor_h ht = nullptr;
253-
EXPECT_EQ(daliTensorListViewAsTensor(tl, &ht), DALI_SUCCESS);
326+
EXPECT_EQ(daliTensorListViewAsTensor(tl, &ht), DALI_SUCCESS) << daliGetLastErrorMessage();
254327
ASSERT_NE(ht, nullptr);
255328
dali::c_api::TensorHandle t(ht);
256329

257330
daliTensorDesc_t desc{};
258-
EXPECT_EQ(daliTensorGetDesc(t, &desc), DALI_SUCCESS);
259-
EXPECT_EQ(desc.data, data_ptr);
331+
EXPECT_EQ(daliTensorGetDesc(t, &desc), DALI_SUCCESS) << daliGetLastErrorMessage();
332+
EXPECT_EQ(desc.data, data.get());
260333
EXPECT_EQ(desc.shape[0], lshape.num_samples());
261334
ASSERT_EQ(desc.ndim, 4);
262335
ASSERT_NE(desc.shape, nullptr);
@@ -265,10 +338,10 @@ TEST(CAPI2_TensorListTest, ViewAsTensor) {
265338
EXPECT_EQ(desc.shape[3], lshape[0][2]);
266339
EXPECT_STREQ(desc.layout, "NHWC");
267340
EXPECT_EQ(desc.dtype, dtype);
268-
EXPECT_EQ(daliTensorGetShape(t, nullptr, nullptr), DALI_SUCCESS);
341+
EXPECT_EQ(daliTensorGetShape(t, nullptr, nullptr), DALI_SUCCESS) << daliGetLastErrorMessage();
269342
int ndim = -1;
270343
const int64_t *shape = nullptr;
271-
EXPECT_EQ(daliTensorGetShape(t, &ndim, &shape), DALI_SUCCESS);
344+
EXPECT_EQ(daliTensorGetShape(t, &ndim, &shape), DALI_SUCCESS) << daliGetLastErrorMessage();
272345
EXPECT_EQ(ndim, 4);
273346
EXPECT_EQ(shape, desc.shape);
274347

@@ -315,9 +388,9 @@ TEST(CAPI2_TensorListTest, ViewAsTensorError) {
315388
"HWC",
316389
data.get(),
317390
offsets,
318-
deleter), DALI_SUCCESS);
391+
deleter), DALI_SUCCESS) << daliGetLastErrorMessage();
319392

320-
void *data_ptr = data.release(); // the buffer is now owned by the tensor list
393+
// The deleter doesn't actually delete - we still own the data.
321394

322395
daliTensor_h ht = nullptr;
323396
EXPECT_EQ(daliTensorListViewAsTensor(tl, &ht), DALI_ERROR_INVALID_OPERATION);
@@ -352,15 +425,15 @@ TEST(CAPI2_TensorTest, CreateDestroy) {
352425
ASSERT_NE(h, nullptr);
353426
dali::c_api::TensorHandle tl(h);
354427
ASSERT_EQ(h, tl.get());
355-
ASSERT_EQ(r, DALI_SUCCESS);
428+
ASSERT_EQ(r, DALI_SUCCESS) << daliGetLastErrorMessage();
356429

357430
int ref = -1;
358-
EXPECT_EQ(daliTensorRefCount(h, &ref), DALI_SUCCESS);
431+
EXPECT_EQ(daliTensorRefCount(h, &ref), DALI_SUCCESS) << daliGetLastErrorMessage();
359432
EXPECT_EQ(ref, 1);
360433
ref = -1;
361434

362435
h = tl.release();
363-
EXPECT_EQ(daliTensorDecRef(h, &ref), DALI_SUCCESS);
436+
EXPECT_EQ(daliTensorDecRef(h, &ref), DALI_SUCCESS) << daliGetLastErrorMessage();
364437
EXPECT_EQ(ref, 0);
365438
}
366439

@@ -396,20 +469,22 @@ void TestTensorResize(daliStorageDevice_t storage_device) {
396469
shape[0] = -1;
397470
EXPECT_EQ(daliTensorResize(t, 3, shape, dtype, "HWC"), DALI_ERROR_INVALID_ARGUMENT);
398471
shape[0] = 1;
399-
EXPECT_EQ(daliTensorResize(t, 3, shape, dtype, "HWC"), DALI_SUCCESS);
472+
EXPECT_EQ(daliTensorResize(t, 3, shape, dtype, "HWC"), DALI_SUCCESS)
473+
<< daliGetLastErrorMessage();
400474

401475
shape[0] = 1080;
402-
EXPECT_EQ(daliTensorResize(t, 3, shape, dtype, nullptr), DALI_SUCCESS);
476+
EXPECT_EQ(daliTensorResize(t, 3, shape, dtype, nullptr), DALI_SUCCESS)
477+
<< daliGetLastErrorMessage();
403478

404479
size_t element_size = dali::TypeTable::GetTypeInfo(dtype).size();
405480

406481
ptrdiff_t offset = 0;
407482
daliTensorDesc_t desc{};
408-
EXPECT_EQ(daliTensorGetDesc(t, &desc), DALI_SUCCESS);
483+
EXPECT_EQ(daliTensorGetDesc(t, &desc), DALI_SUCCESS) << daliGetLastErrorMessage();
409484
ASSERT_EQ(desc.ndim, 3);
410-
ASSERT_NE(desc.data, nullptr);
411485
EXPECT_STREQ(desc.layout, "HWC");
412486
EXPECT_EQ(desc.dtype, dtype);
487+
ASSERT_NE(desc.shape, nullptr);
413488
for (int j = 0; j < 3; j++)
414489
EXPECT_EQ(desc.shape[j], shape[j]);
415490
size_t sample_bytes = volume(dali::make_cspan(desc.shape, desc.ndim)) * element_size;

include/dali/dali.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -723,12 +723,13 @@ DALI_API daliResult_t daliTensorListAttachBuffer(
723723
* @param dtype the type of the element of the tensor;
724724
* if dtype is DALI_NO_TYPE, then the type is taken from samples[0].dtype
725725
* @param layout a layout string describing the order of axes in each sample (e.g. HWC),
726-
* if NULL, and the TensorList's number of dimensions is equal to `ndim`,
727-
* then the current layout is kept;
726+
* if NULL, the layout is taken from samples[0].layout; if it's still NULL,
727+
* the current layout is kept, if possible;
728728
* if `layout` is an empty string, the tensor list's layout is cleared
729729
* @param samples the descriptors of the tensors to be attached to the TensorList;
730730
* the `ndim` and `dtype` of the samples must match and they must match the
731-
* values of `ndim` and `dtype` parameters.
731+
* values of `ndim` and `dtype` parameters; the layout must be either NULL
732+
* or match the `layout` argument (if provided).
732733
* @param sample_deleters optional deleters, one for each sample
733734
*
734735
* NOTE: If the sample_deleters specify the same object multiple times, its destructor must

0 commit comments

Comments
 (0)