Skip to content

Commit 67bd7ba

Browse files
authored
C API 2.0: External source info (#5872)
Add an ability to obtain information about pipeline inputs including: - name - device - dtype - ndim - layout --------- Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent 3bfccb2 commit 67bd7ba

File tree

5 files changed

+217
-9
lines changed

5 files changed

+217
-9
lines changed

dali/c_api_2/pipeline.cc

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,74 @@ void PipelineWrapper::Prefetch() {
8888
pipeline_->Prefetch();
8989
}
9090

91+
int PipelineWrapper::GetInputCount() const {
92+
return pipeline_->GetInputOperators().size();
93+
}
94+
95+
daliPipelineIODesc_t PipelineWrapper::GetInputDesc(int idx) const & {
96+
auto &inputs = pipeline_->GetInputOperators();
97+
size_t n = inputs.size();
98+
if (idx < 0 || static_cast<size_t>(idx) >= n)
99+
throw std::out_of_range(make_string(
100+
"The input index ", idx, " is out of range. The valid range is [0..", n-1, "]."));
101+
102+
if (input_names_.size() != n) {
103+
input_names_.clear();
104+
input_names_.reserve(n);
105+
for (auto it = inputs.begin(); it != inputs.end(); it++) {
106+
input_names_.push_back(it->first);
107+
}
108+
}
109+
return GetInputDesc(input_names_[idx]);
110+
}
111+
112+
namespace {
113+
114+
template <typename Backend>
115+
void FillPipelineDesc(daliPipelineIODesc_t &desc, const InputOperator<Backend> &inp) {
116+
int ndim = inp.in_ndim();
117+
if (ndim >= 0) {
118+
desc.ndim_present = true;
119+
desc.ndim = ndim;
120+
}
121+
auto dtype = inp.in_dtype();
122+
if (dtype != DALI_NO_TYPE) {
123+
desc.dtype_present = true;
124+
desc.dtype = dtype;
125+
}
126+
auto &layout = inp.in_layout();
127+
if (layout.size())
128+
desc.layout = layout.c_str();
129+
else
130+
desc.layout = nullptr;
131+
}
132+
133+
} // namespace
134+
135+
daliPipelineIODesc_t PipelineWrapper::GetInputDesc(std::string_view name) const & {
136+
auto &inputs = pipeline_->GetInputOperators();
137+
auto it = inputs.find(name);
138+
if (it == inputs.end())
139+
throw invalid_key(make_string("The input with the name \"", name, "\" was not found."));
140+
141+
daliPipelineIODesc_t desc{};
142+
desc.name = it->first.c_str();
143+
desc.device = it->second->op_type == OpType::GPU ? DALI_STORAGE_GPU : DALI_STORAGE_CPU;
144+
auto *op = pipeline_->GetOperator(name);
145+
if (auto *inp = dynamic_cast<InputOperator<CPUBackend> *>(op))
146+
FillPipelineDesc(desc, *inp);
147+
else if (auto *inp = dynamic_cast<InputOperator<GPUBackend> *>(op))
148+
FillPipelineDesc(desc, *inp);
149+
else if (auto *inp = dynamic_cast<InputOperator<MixedBackend> *>(op))
150+
FillPipelineDesc(desc, *inp);
151+
else
152+
throw std::logic_error(make_string(
153+
"Internal error - the operator \"", name, "\" was found in the input operators map, but "
154+
"it's not an instance of InputOperator<Backend>."));
155+
return desc;
156+
}
157+
158+
91159
int PipelineWrapper::GetFeedCount(std::string_view input_name) {
92160
return pipeline_->InputFeedCount(input_name);
93161
}
@@ -122,21 +190,22 @@ int PipelineWrapper::GetOutputCount() const {
122190
return pipeline_->output_descs().size();
123191
}
124192

125-
daliPipelineOutputDesc_t PipelineWrapper::GetOutputDesc(int idx) const {
193+
daliPipelineIODesc_t PipelineWrapper::GetOutputDesc(int idx) const & {
126194
auto &outputs = pipeline_->output_descs();
127195
int nout = outputs.size();
128196
if (idx < 0 || idx >= nout)
129197
throw std::out_of_range(make_string(
130198
"The output index ", idx, " is out of range. "
131199
"Valid range is [0..", nout-1, "]."));
132200
auto &out = outputs[idx];
133-
daliPipelineOutputDesc_t desc{};
201+
daliPipelineIODesc_t desc{};
134202
desc.device = static_cast<daliStorageDevice_t>(out.device);
135203
desc.dtype = out.dtype;
136204
desc.dtype_present = out.dtype != DALI_NO_TYPE;
137205
desc.name = out.name.c_str();
138206
desc.ndim = out.ndim;
139207
desc.ndim_present = out.ndim >= 0;
208+
desc.layout = out.layout.c_str();
140209
return desc;
141210
}
142211

@@ -256,6 +325,37 @@ daliResult_t daliPipelineFeedInput(
256325
DALI_EPILOG();
257326
}
258327

328+
daliResult_t daliPipelineGetInputCount(daliPipeline_h pipeline, int *out_input_count) {
329+
DALI_PROLOG();
330+
auto p = ToPointer(pipeline);
331+
CHECK_OUTPUT(out_input_count);
332+
*out_input_count = p->GetInputCount();
333+
DALI_EPILOG();
334+
}
335+
336+
DALI_API daliResult_t daliPipelineGetInputDescByIdx(
337+
daliPipeline_h pipeline,
338+
daliPipelineIODesc_t *out_input_desc,
339+
int index) {
340+
DALI_PROLOG();
341+
auto p = ToPointer(pipeline);
342+
CHECK_OUTPUT(out_input_desc);
343+
*out_input_desc = p->GetInputDesc(index);
344+
DALI_EPILOG();
345+
}
346+
347+
DALI_API daliResult_t daliPipelineGetInputDesc(
348+
daliPipeline_h pipeline,
349+
daliPipelineIODesc_t *out_input_desc,
350+
const char *name) {
351+
DALI_PROLOG();
352+
auto p = ToPointer(pipeline);
353+
CHECK_OUTPUT(out_input_desc);
354+
NOT_NULL(name);
355+
*out_input_desc = p->GetInputDesc(name);
356+
DALI_EPILOG();
357+
}
358+
259359
daliResult_t daliPipelineGetOutputCount(daliPipeline_h pipeline, int *out_count) {
260360
DALI_PROLOG();
261361
auto pipe = ToPointer(pipeline);
@@ -266,7 +366,7 @@ daliResult_t daliPipelineGetOutputCount(daliPipeline_h pipeline, int *out_count)
266366

267367
daliResult_t daliPipelineGetOutputDesc(
268368
daliPipeline_h pipeline,
269-
daliPipelineOutputDesc_t *out_desc,
369+
daliPipelineIODesc_t *out_desc,
270370
int index) {
271371
DALI_PROLOG();
272372
auto pipe = ToPointer(pipeline);

dali/c_api_2/pipeline.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,13 @@ class PipelineWrapper : public _DALIPipeline {
6060

6161
int GetOutputCount() const;
6262

63-
daliPipelineOutputDesc_t GetOutputDesc(int idx) const;
63+
daliPipelineIODesc_t GetOutputDesc(int idx) const &;
64+
65+
int GetInputCount() const;
66+
67+
daliPipelineIODesc_t GetInputDesc(int idx) const &;
68+
69+
daliPipelineIODesc_t GetInputDesc(std::string_view name) const &;
6470

6571
/** Retrieves the underlying DALI Pipeline object */
6672
dali::Pipeline *Unwrap() const & {
@@ -77,6 +83,7 @@ class PipelineWrapper : public _DALIPipeline {
7783
AccessOrder order);
7884

7985
std::unique_ptr<Pipeline> pipeline_;
86+
mutable std::vector<std::string_view> input_names_;
8087
};
8188

8289
PipelineWrapper *ToPointer(daliPipeline_h handle);

dali/c_api_2/pipeline_test.cc

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,16 @@ std::string GetPipelineWithExternalSource(
159159
StorageDevice dev_type,
160160
int max_batch_size,
161161
int num_threads,
162-
int device_id) {
162+
int device_id,
163+
bool extended_input_desc = false) {
163164
Pipeline p(max_batch_size, num_threads, device_id);
164165
OpSpec src = OpSpec("ExternalSource")
165166
.AddOutput("out", dev_type)
166167
.AddArg("device", dev_type == StorageDevice::CPU ? "cpu" : "gpu")
167168
.AddArg("name", "ext")
168169
.AddArg("batch_size", max_batch_size);
170+
if (extended_input_desc)
171+
src.AddArg("ndim", 3).AddArg("layout", "HWC").AddArg("dtype", DALI_UINT8);
169172
p.AddOperator(src, "ext");
170173
p.SetOutputDescs({ {"out", to_string(dev_type)} });
171174
return p.SerializeToProtobuf();
@@ -264,7 +267,7 @@ void TestPipelineRun(PipelineType ptype) {
264267
int count;
265268
CHECK_DALI(daliPipelineGetOutputCount(h, &count));
266269
ASSERT_EQ(count, 2);
267-
daliPipelineOutputDesc_t desc{};
270+
daliPipelineIODesc_t desc{};
268271
EXPECT_EQ(daliPipelineGetOutputDesc(h, &desc, -1), DALI_ERROR_OUT_OF_RANGE);
269272
EXPECT_EQ(daliPipelineGetOutputDesc(h, &desc, count), DALI_ERROR_OUT_OF_RANGE);
270273
daliClearLastError();
@@ -502,7 +505,47 @@ TEST(CAPI2_PipelineTest, FeedInputGPUAsync) {
502505
TestFeedInput<GPUBackend>({});
503506
}
504507

508+
TEST(CAPI2_PipelineTest, InputDescSimple) {
509+
auto proto = GetPipelineWithExternalSource(dali::StorageDevice::GPU, 4, 4, 0, false);
510+
daliPipelineParams_t params{};
511+
params.exec_type_present = true;
512+
params.exec_type = DALI_EXEC_DYNAMIC;
505513

514+
auto h = Deserialize(proto, params);
515+
CHECK_DALI(daliPipelineBuild(h));
516+
int count = 0;
517+
CHECK_DALI(daliPipelineGetInputCount(h, &count));
518+
ASSERT_EQ(count, 1);
519+
daliPipelineIODesc_t desc{};
520+
CHECK_DALI(daliPipelineGetInputDescByIdx(h, &desc, 0));
521+
EXPECT_EQ(desc.device, DALI_STORAGE_GPU);
522+
EXPECT_STREQ(desc.name, "ext");
523+
EXPECT_FALSE(desc.ndim_present);
524+
EXPECT_FALSE(desc.dtype_present);
525+
EXPECT_EQ(desc.layout, nullptr);
526+
}
527+
528+
TEST(CAPI2_PipelineTest, InputDescExtended) {
529+
auto proto = GetPipelineWithExternalSource(dali::StorageDevice::CPU, 4, 4, 0, true);
530+
daliPipelineParams_t params{};
531+
params.exec_type_present = true;
532+
params.exec_type = DALI_EXEC_DYNAMIC;
533+
534+
auto h = Deserialize(proto, params);
535+
CHECK_DALI(daliPipelineBuild(h));
536+
int count = 0;
537+
CHECK_DALI(daliPipelineGetInputCount(h, &count));
538+
ASSERT_EQ(count, 1);
539+
daliPipelineIODesc_t desc{};
540+
CHECK_DALI(daliPipelineGetInputDescByIdx(h, &desc, 0));
541+
EXPECT_EQ(desc.device, DALI_STORAGE_CPU);
542+
EXPECT_STREQ(desc.name, "ext");
543+
EXPECT_TRUE(desc.ndim_present);
544+
EXPECT_EQ(desc.ndim, 3);
545+
EXPECT_TRUE(desc.dtype_present);
546+
EXPECT_EQ(desc.dtype, DALI_UINT8);
547+
EXPECT_STREQ(desc.layout, "HWC");
548+
}
506549

507550
} // namespace test
508551
} // namespace c_api

dali/pipeline/pipeline.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,15 @@ class DLL_PUBLIC Pipeline {
263263
*/
264264
DLL_PUBLIC const graph::OpNode *GetInputOperatorNode(std::string_view name);
265265

266+
/**
267+
* @brief Get input operatos as a name-to-node mapping.
268+
*
269+
*/
270+
DLL_PUBLIC const auto &GetInputOperators() const & {
271+
DALI_ENFORCE(built_);
272+
return input_operators_;
273+
}
274+
266275
/** @{ */
267276
/**
268277
* @brief Performs some checks on the user-constructed pipeline, setups data

include/dali/dali.h

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ typedef struct _DALIPipelineParams {
336336
} daliPipelineParams_t;
337337

338338
/** Describes an output of a DALI Pipeline */
339-
typedef struct _DALIPipelineOutputDesc {
339+
typedef struct _DALIPipelineIODesc {
340340
const char *name;
341341
daliStorageDevice_t device;
342342
struct {
@@ -345,7 +345,8 @@ typedef struct _DALIPipelineOutputDesc {
345345
};
346346
daliDataType_t dtype;
347347
int ndim;
348-
} daliPipelineOutputDesc_t;
348+
const char *layout;
349+
} daliPipelineIODesc_t;
349350

350351
/** Creates an empty pipeline. */
351352
DALI_API daliResult_t daliPipelineCreate(
@@ -461,6 +462,54 @@ DALI_API daliResult_t daliPipelineFeedInput(
461462
daliFeedInputFlags_t options,
462463
const cudaStream_t *stream);
463464

465+
/** Gets the number of pipeline inputs.
466+
*
467+
* NOTE: The pipeline must be built before calling this function.
468+
*
469+
* @param pipeline [in] The pipeline
470+
* @param out_input_count [out] A pointer to the location where the number of pipeline inputs is
471+
* stored.
472+
*
473+
* @retval DALI_SUCCESS
474+
* @retval DALI_ERROR_INVALID_OPERATION the pipeline wasn't built before the call
475+
*/
476+
DALI_API daliResult_t daliPipelineGetInputCount(daliPipeline_h pipeline, int *out_input_count);
477+
478+
/** Gets a descriptor of a pipeline input specified by index.
479+
*
480+
* NOTE: The pipeline must be built before calling this function.
481+
*
482+
* @param pipeline [in] The pipeline
483+
* @param out_input_desc [out] A pointer to the location where the descriptor is written.
484+
* @param index [in] The 0-based index of the input. See `daliPipelineGetInputCount`.
485+
*
486+
* @retval DALI_SUCCESS
487+
* @retval DALI_ERROR_INVALID_OPERATION the pipeline wasn't built before the call
488+
* @retval DALI_ERROR_OUT_OF_RANGE the index is not a valid 0-based index of the an input
489+
*/
490+
DALI_API daliResult_t daliPipelineGetInputDescByIdx(
491+
daliPipeline_h pipeline,
492+
daliPipelineIODesc_t *out_input_desc,
493+
int index);
494+
495+
/** Gets a descriptor of a pipeline input specified by its name.
496+
*
497+
* NOTE: The pipeline must be built before calling this function.
498+
*
499+
* @param pipeline [in] The pipeline
500+
* @param out_input_desc [out] A pointer to the location where the descriptor is written.
501+
* @param name [in] The name of the input whose descriptor to obtain.]
502+
*
503+
* @retval DALI_SUCCESS
504+
* @retval DALI_ERROR_INVALID_OPERATION the pipeline wasn't built before the call
505+
* @retval DALI_ERROR_INVALID_KEY if `input_name` is not a valid name of an input of the
506+
* pipeline
507+
*/
508+
DALI_API daliResult_t daliPipelineGetInputDesc(
509+
daliPipeline_h pipeline,
510+
daliPipelineIODesc_t *out_input_desc,
511+
const char *name);
512+
464513
/** Gets the number of pipeline outputs.
465514
*
466515
* @param pipeline [in] The pipeline
@@ -481,7 +530,7 @@ DALI_API daliResult_t daliPipelineGetOutputCount(daliPipeline_h pipeline, int *o
481530
*/
482531
DALI_API daliResult_t daliPipelineGetOutputDesc(
483532
daliPipeline_h pipeline,
484-
daliPipelineOutputDesc_t *out_desc,
533+
daliPipelineIODesc_t *out_desc,
485534
int index);
486535

487536

0 commit comments

Comments
 (0)