Skip to content

Commit 18f285c

Browse files
committed
Add source info functions + setter APIs.
Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent ec9c235 commit 18f285c

File tree

4 files changed

+184
-4
lines changed

4 files changed

+184
-4
lines changed

dali/c_api_2/data_objects.cc

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,16 @@ daliResult_t daliTensorResize(
147147
DALI_EPILOG();
148148
}
149149

150+
daliResult_t daliTensorGetBufferPlacement(
151+
daliTensor_h tensor,
152+
daliBufferPlacement_t *out_placement) {
153+
DALI_PROLOG();
154+
auto *ptr = ToPointer(tensor);
155+
CHECK_OUTPUT(out_placement);
156+
*out_placement = ptr->GetBufferPlacement();
157+
DALI_EPILOG();
158+
}
159+
150160
daliResult_t daliTensorSetLayout(
151161
daliTensor_h tensor,
152162
const char *layout) {
@@ -212,6 +222,24 @@ daliResult_t daliTensorGetShape(
212222
DALI_EPILOG();
213223
}
214224

225+
daliResult_t daliTensorGetSourceInfo(
226+
daliTensor_h tensor,
227+
const char **out_source_info) {
228+
DALI_PROLOG();
229+
auto *ptr = ToPointer(tensor);
230+
CHECK_OUTPUT(out_source_info);
231+
*out_source_info = ptr->GetSourceInfo();
232+
DALI_EPILOG();
233+
}
234+
235+
daliResult_t daliTensorSetSourceInfo(
236+
daliTensor_h tensor,
237+
const char *source_info) {
238+
DALI_PROLOG();
239+
ToPointer(tensor)->SetSourceInfo(source_info);
240+
DALI_EPILOG();
241+
}
242+
215243
//////////////////////////////////////////////////////////////////////////////
216244
// TensorList
217245
//////////////////////////////////////////////////////////////////////////////
@@ -294,6 +322,17 @@ daliResult_t daliTensorListResize(
294322
DALI_EPILOG();
295323
}
296324

325+
daliResult_t daliTensorListGetBufferPlacement(
326+
daliTensorList_h tensor_list,
327+
daliBufferPlacement_t *out_placement) {
328+
DALI_PROLOG();
329+
auto *ptr = ToPointer(tensor_list);
330+
CHECK_OUTPUT(out_placement);
331+
*out_placement = ptr->GetBufferPlacement();
332+
DALI_EPILOG();
333+
}
334+
335+
297336
daliResult_t daliTensorListSetLayout(
298337
daliTensorList_h tensor_list,
299338
const char *layout) {
@@ -363,6 +402,26 @@ daliResult_t daliTensorListGetTensorDesc(
363402
DALI_EPILOG();
364403
}
365404

405+
daliResult_t daliTensorListGetSourceInfo(
406+
daliTensorList_h tensor_list,
407+
const char **out_source_info,
408+
int sample_idx) {
409+
DALI_PROLOG();
410+
auto *ptr = ToPointer(tensor_list);
411+
CHECK_OUTPUT(out_source_info);
412+
*out_source_info = ptr->GetSourceInfo(sample_idx);
413+
DALI_EPILOG();
414+
}
415+
416+
daliResult_t daliTensorListSetSourceInfo(
417+
daliTensorList_h tensor_list,
418+
int sample_idx,
419+
const char *source_info) {
420+
DALI_PROLOG();
421+
ToPointer(tensor_list)->SetSourceInfo(sample_idx, source_info);
422+
DALI_EPILOG();
423+
}
424+
366425
daliResult_t daliTensorListViewAsTensor(
367426
daliTensorList_h tensor_list,
368427
daliTensor_h *out_tensor) {

dali/c_api_2/data_objects.h

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ class ITensor : public _DALITensor, public RefCountedObject {
7373

7474
virtual const TensorShape<> &GetShape() const & = 0;
7575

76+
virtual const char *GetSourceInfo() const & = 0;
77+
78+
virtual void SetSourceInfo(const char *source_info) = 0;
79+
7680
template <typename Backend>
7781
const std::shared_ptr<Tensor<Backend>> &Unwrap() const &;
7882

@@ -129,6 +133,10 @@ class ITensorList : public _DALITensorList, public RefCountedObject {
129133

130134
virtual RefCountedPtr<ITensor> ViewAsTensor() const = 0;
131135

136+
virtual const char *GetSourceInfo(int sample) const & = 0;
137+
138+
virtual void SetSourceInfo(int sample, const char *source_info) = 0;
139+
132140
template <typename Backend>
133141
const std::shared_ptr<TensorList<Backend>> &Unwrap() const &;
134142

@@ -297,6 +305,17 @@ class TensorWrapper : public ITensor {
297305
return t_->shape();
298306
}
299307

308+
const char *GetSourceInfo() const & override {
309+
const char *info = t_->GetMeta().GetSourceInfo().c_str();
310+
if (info && !*info)
311+
return nullptr;
312+
return info;
313+
}
314+
315+
void SetSourceInfo(const char *source_info) override {
316+
t_->SetSourceInfo(source_info ? source_info : "");
317+
}
318+
300319
const auto &NativePtr() const & {
301320
return t_;
302321
}
@@ -583,11 +602,9 @@ class TensorListWrapper : public ITensorList {
583602
}
584603

585604
daliTensorDesc_t GetTensorDesc(int sample) const override {
586-
auto &shape = tl_->shape();
587-
if (sample < 0 || sample >= shape.num_samples())
588-
throw std::out_of_range(make_string("The sample index ", sample, " is out of range. "
589-
"Valid indices are [0..", shape.num_samples() - 1, "]."));
605+
ValidateSampleIdx(sample);
590606
daliTensorDesc_t desc{};
607+
auto &shape = tl_->shape();
591608
desc.ndim = shape.sample_dim();
592609
desc.data = tl_->raw_mutable_tensor(sample);
593610
desc.dtype = tl_->type();
@@ -600,6 +617,19 @@ class TensorListWrapper : public ITensorList {
600617
return tl_->shape();
601618
}
602619

620+
const char *GetSourceInfo(int sample) const & override {
621+
ValidateSampleIdx(sample);
622+
const char *info = tl_->GetMeta(sample).GetSourceInfo().c_str();
623+
if (info && !*info)
624+
return nullptr; // return empty string as NULL
625+
return info;
626+
}
627+
628+
void SetSourceInfo(int sample, const char *source_info) override {
629+
ValidateSampleIdx(sample);
630+
tl_->SetSourceInfo(sample, source_info ? source_info : "");
631+
}
632+
603633
RefCountedPtr<ITensor> ViewAsTensor() const override {
604634
if (!tl_->IsContiguous())
605635
throw std::runtime_error(
@@ -629,6 +659,12 @@ class TensorListWrapper : public ITensorList {
629659
return tl_;
630660
}
631661

662+
inline void ValidateSampleIdx(int idx) const {
663+
if (idx < 0 || idx >= tl_->num_samples())
664+
throw std::out_of_range(make_string("The sample index ", idx, " is out of range. "
665+
"Valid indices are [0..", tl_->num_samples() - 1, "]."));
666+
}
667+
632668
private:
633669
std::shared_ptr<TensorList<Backend>> tl_;
634670
};

dali/c_api_2/data_objects_test.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <gtest/gtest.h>
1717
#include "dali/c_api_2/managed_handle.h"
1818
#include "dali/core/span.h"
19+
#include "dali/core/device_guard.h"
1920

2021
TEST(CAPI2_TensorListTest, NullHandle) {
2122
daliTensorList_h h = nullptr;
@@ -65,6 +66,7 @@ inline auto CreateTensorList(daliBufferPlacement_t placement) {
6566
void TestTensorListResize(daliStorageDevice_t storage_device) {
6667
daliBufferPlacement_t placement{};
6768
placement.device_type = storage_device;
69+
placement.pinned = true;
6870
int64_t shapes[] = {
6971
480, 640, 3,
7072
600, 800, 4,
@@ -74,6 +76,13 @@ void TestTensorListResize(daliStorageDevice_t storage_device) {
7476
daliDataType_t dtype = DALI_UINT32;
7577

7678
auto tl = CreateTensorList(placement);
79+
80+
daliBufferPlacement_t test_placement{};
81+
EXPECT_EQ(daliTensorListGetBufferPlacement(tl, &test_placement), DALI_SUCCESS);
82+
EXPECT_EQ(test_placement.device_type, placement.device_type);
83+
EXPECT_EQ(test_placement.device_id, placement.device_id);
84+
EXPECT_EQ(test_placement.pinned, placement.pinned);
85+
7786
EXPECT_EQ(daliTensorListResize(tl, 4, 3, nullptr, dtype, nullptr), DALI_ERROR_INVALID_ARGUMENT);
7887
EXPECT_EQ(daliTensorListResize(tl, -1, 3, shapes, dtype, nullptr), DALI_ERROR_INVALID_ARGUMENT);
7988
EXPECT_EQ(daliTensorListResize(tl, 4, -1, shapes, dtype, nullptr), DALI_ERROR_INVALID_ARGUMENT);
@@ -289,8 +298,14 @@ TEST(CAPI2_TensorListTest, AttachSamples) {
289298

290299

291300
TEST(CAPI2_TensorListTest, ViewAsTensor) {
301+
int num_dev = 0;
302+
CUDA_CALL(cudaGetDeviceCount(&num_dev));
303+
// use the last device
304+
dali::DeviceGuard dg(num_dev - 1);
305+
292306
daliBufferPlacement_t placement{};
293307
placement.device_type = DALI_STORAGE_CPU;
308+
placement.pinned = true;
294309
using element_t = int;
295310
daliDataType_t dtype = dali::type2id<element_t>::value;
296311
dali::TensorListShape<> lshape = dali::uniform_list_shape(4, { 480, 640, 3 });
@@ -327,6 +342,12 @@ TEST(CAPI2_TensorListTest, ViewAsTensor) {
327342
ASSERT_NE(ht, nullptr);
328343
dali::c_api::TensorHandle t(ht);
329344

345+
daliBufferPlacement_t tensor_placement{};
346+
EXPECT_EQ(daliTensorGetBufferPlacement(ht, &tensor_placement), DALI_SUCCESS);
347+
EXPECT_EQ(tensor_placement.device_type, placement.device_type);
348+
EXPECT_EQ(tensor_placement.device_id, placement.device_id);
349+
EXPECT_EQ(tensor_placement.pinned, placement.pinned);
350+
330351
daliTensorDesc_t desc{};
331352
EXPECT_EQ(daliTensorGetDesc(t, &desc), DALI_SUCCESS) << daliGetLastErrorMessage();
332353
EXPECT_EQ(desc.data, data.get());
@@ -507,3 +528,38 @@ TEST(CAPI2_TensorTest, ResizeCPU) {
507528
TEST(CAPI2_TensorTest, ResizeGPU) {
508529
TestTensorResize(DALI_STORAGE_GPU);
509530
}
531+
532+
TEST(CAPI2_TensorTest, SourceInfo) {
533+
auto t = CreateTensor({});
534+
const char *out_src_info = "junk";
535+
EXPECT_EQ(daliTensorGetSourceInfo(t, &out_src_info), DALI_SUCCESS);
536+
EXPECT_EQ(out_src_info, nullptr);
537+
538+
EXPECT_EQ(daliTensorSetSourceInfo(t, "source_info"), DALI_SUCCESS);
539+
EXPECT_EQ(daliTensorGetSourceInfo(t, &out_src_info), DALI_SUCCESS);
540+
EXPECT_STREQ(out_src_info, "source_info");
541+
}
542+
543+
TEST(CAPI2_TensorListTest, SourceInfo) {
544+
auto t = CreateTensorList({});
545+
ASSERT_EQ(daliTensorListResize(t, 5, 0, nullptr, DALI_UINT8, nullptr), DALI_SUCCESS);
546+
547+
const char *out_src_info = "junk";
548+
EXPECT_EQ(daliTensorListGetSourceInfo(t, &out_src_info, 0), DALI_SUCCESS);
549+
EXPECT_EQ(out_src_info, nullptr);
550+
551+
EXPECT_EQ(daliTensorListSetSourceInfo(t, 0, "quick"), DALI_SUCCESS);
552+
EXPECT_EQ(daliTensorListSetSourceInfo(t, 2, "brown"), DALI_SUCCESS);
553+
EXPECT_EQ(daliTensorListSetSourceInfo(t, 4, "fox"), DALI_SUCCESS);
554+
555+
EXPECT_EQ(daliTensorListGetSourceInfo(t, &out_src_info, 0), DALI_SUCCESS);
556+
EXPECT_STREQ(out_src_info, "quick");
557+
EXPECT_EQ(daliTensorListGetSourceInfo(t, &out_src_info, 1), DALI_SUCCESS);
558+
EXPECT_EQ(out_src_info, nullptr);
559+
EXPECT_EQ(daliTensorListGetSourceInfo(t, &out_src_info, 2), DALI_SUCCESS);
560+
EXPECT_STREQ(out_src_info, "brown");
561+
EXPECT_EQ(daliTensorListGetSourceInfo(t, &out_src_info, 3), DALI_SUCCESS);
562+
EXPECT_EQ(out_src_info, nullptr);
563+
EXPECT_EQ(daliTensorListGetSourceInfo(t, &out_src_info, 4), DALI_SUCCESS);
564+
EXPECT_STREQ(out_src_info, "fox");
565+
}

include/dali/dali.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,22 @@ DALI_API daliResult_t daliTensorListGetSourceInfo(
874874
const char **out_source_info,
875875
int sample_idx);
876876

877+
/** Sets the "source info" metadata of a tensor in a list.
878+
*
879+
* A tensor can be associated with a "source info" string, which typically is the file name,
880+
* but can also contain an index in a container, key, etc.
881+
*
882+
* @param tensor_list [in] The tensor list
883+
* @param sample_idx [in] The index of the sample, whose source info will is being set.
884+
* @param source_info [in] A source info string (e.g. filename) to associate with the tensor.
885+
* Passing NULL is equivalent to passing an empty string.
886+
*/
887+
DALI_API daliResult_t daliTensorListSetSourceInfo(
888+
daliTensorList_h tensor_list,
889+
int sample_idx,
890+
const char *source_info);
891+
892+
877893
/** Gets the tensor descriptor of the specified sample.
878894
*
879895
* @param tensor_list [in] The tensor list
@@ -1107,6 +1123,19 @@ DALI_API daliResult_t daliTensorGetSourceInfo(
11071123
daliTensor_h tensor,
11081124
const char **out_source_info);
11091125

1126+
/** Sets the "source info" metadata of a tensor.
1127+
*
1128+
* A tensor can be associated with a "source info" string, which typically is the file name,
1129+
* but can also contain an index in a container, key, etc.
1130+
*
1131+
* @param tensor [in] The tensor
1132+
* @param source_info [in] A source info string (e.g. filename) to associate with the tensor.
1133+
* Passing NULL is equivalent to passing an empty string.
1134+
*/
1135+
DALI_API daliResult_t daliTensorSetSourceInfo(
1136+
daliTensor_h tensor,
1137+
const char *source_info);
1138+
11101139
/** Gets the descriptor of the data in the tensor.
11111140
*
11121141
* @param tensor [in] The tensor

0 commit comments

Comments
 (0)