Skip to content

Commit a473467

Browse files
authored
Add warp_perspective operator (#5542)
Signed-off-by: Rafal Banas <[email protected]>
1 parent ecf8432 commit a473467

File tree

13 files changed

+1082
-12
lines changed

13 files changed

+1082
-12
lines changed

cmake/Dependencies.common.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ if (BUILD_CVCUDA)
264264
set(DALI_BUILD_PYTHON ${BUILD_PYTHON})
265265
set(BUILD_PYTHON OFF)
266266
# for now we build only selected CV-CUDA operators (median blur, morphology, warp)
267-
set(CV_CUDA_SRC_PATERN medianblur median_blur morphology)
267+
set(CV_CUDA_SRC_PATERN medianblur median_blur morphology warp)
268268
check_and_add_cmake_submodule(${PROJECT_SOURCE_DIR}/third_party/cvcuda)
269269
set(BUILD_PYTHON ${DALI_BUILD_PYTHON})
270270
endif()

dali/operators/image/remap/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
1+
# Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -12,6 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
if (BUILD_CVCUDA)
16+
add_subdirectory(cvcuda)
17+
endif()
18+
1519
# Get all the source files and dump test files
1620
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
1721
collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Get all the source files and dump test files
16+
collect_headers(DALI_INST_HDRS PARENT_SCOPE)
17+
collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE)
18+
collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "dali/operators/image/remap/cvcuda/matrix_adjust.h"
16+
17+
#include <dali/core/util.h>
18+
#include <dali/core/geom/mat.h>
19+
#include <nvcv/cuda/TensorWrap.hpp>
20+
21+
namespace dali {
22+
namespace warp_perspective {
23+
24+
// Accessor type used by the kernel below to reach the per-sample matrices stored as floats.
// NOTE(review): the two compile-time arguments (9 * sizeof(float), sizeof(float)) are presumably
// the byte strides between consecutive matrices and between consecutive elements, i.e. the
// matrices are assumed to be densely packed - confirm against the nvcv TensorWrap documentation.
using MatricesWrap = nvcv::cuda::TensorWrap<float, 9 * sizeof(float), sizeof(float)>;
25+
26+
/**
 * @brief Rewrites, in place, every 3x3 matrix of the batch so that it operates on
 *        OpenCV-style pixel coordinates.
 *
 * Launch layout: 1D grid of 1D blocks, one thread per matrix. Threads whose global index
 * is not smaller than `batch_size` do nothing, so the grid may be over-provisioned.
 *
 * @param wrap        accessor to the matrices; `wrap.ptr(i)` is reinterpreted as a `mat3`
 * @param batch_size  number of matrices to process
 */
__global__ void adjustMatricesKernel2(MatricesWrap wrap, int batch_size) {
  // Switching the coordinate-system basis amounts to multiplying the matrix on both sides
  // by opposite translation matrices (a half-pixel shift and its inverse).
  // No special handling is needed for inverse_map: inverting the matrix preserves
  // the basis change, so the very same routine serves both cases.
  const int sample_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (sample_idx < batch_size) {
    mat3 *target = reinterpret_cast<mat3 *>(wrap.ptr(sample_idx));

    mat3 half_pixel_shift = {{{1, 0, 0.5f}, {0, 1, 0.5f}, {0, 0, 1}}};

    // Right-hand side multiplication: shift by +0.5 in x and y.
    mat3 adjusted = *target * half_pixel_shift;

    // Left-hand side multiplication by the opposite shift (-0.5), written out row by row:
    // only the first two rows change, each by subtracting half of the last row.
    adjusted.set_row(0, adjusted.row(0) - adjusted.row(2) * 0.5f);
    adjusted.set_row(1, adjusted.row(1) - adjusted.row(2) * 0.5f);

    *target = adjusted;
  }
}
49+
50+
/**
 * @brief Enqueues, on `stream`, an in-place adjustment of all perspective matrices in `matrices`
 *        (see adjustMatricesKernel2 for the math).
 *
 * The call is asynchronous with respect to the host: it only launches the kernel.
 * An empty batch is a no-op.
 *
 * @param matrices  tensor exported as strided CUDA data; its outermost extent is the batch size
 * @param stream    CUDA stream on which the kernel is launched
 */
void adjustMatrices(nvcv::Tensor &matrices, cudaStream_t stream) {
  auto data = *matrices.exportData<nvcv::TensorDataStridedCuda>();
  int bs = data.shape()[0];
  if (bs <= 0) {
    // Nothing to adjust. Launching with zero blocks / zero threads would be
    // an invalid launch configuration, so bail out before touching the GPU.
    return;
  }
  MatricesWrap wrap(data);

  // The kernel bounds-checks its global index, so a fixed block size is safe even when
  // the batch is smaller than one block.
  constexpr int kThreadsPerBlock = 256;
  int num_blocks = div_ceil(bs, kThreadsPerBlock);
  adjustMatricesKernel2<<<num_blocks, kThreadsPerBlock, 0, stream>>>(wrap, bs);
}
59+
60+
} // namespace warp_perspective
61+
} // namespace dali
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef DALI_OPERATORS_IMAGE_REMAP_CVCUDA_MATRIX_ADJUST_H_
16+
#define DALI_OPERATORS_IMAGE_REMAP_CVCUDA_MATRIX_ADJUST_H_
17+
18+
#include <dali/core/geom/mat.h>
19+
#include <dali/pipeline/data/tensor.h>
20+
#include <nvcv/Tensor.hpp>
21+
22+
namespace dali {
23+
namespace warp_perspective {
24+
25+
/**
26+
* @brief Modifies (in-place) tensor of perspective matrices to match
27+
* the OpenCV convention of pixel origin (center instead of corner).
28+
*/
29+
void adjustMatrices(nvcv::Tensor &matrices, cudaStream_t stream);
30+
31+
} // namespace warp_perspective
32+
} // namespace dali
33+
34+
#endif // DALI_OPERATORS_IMAGE_REMAP_CVCUDA_MATRIX_ADJUST_H_
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <cvcuda/OpWarpPerspective.hpp>
16+
#include <nvcv/Image.hpp>
17+
#include <nvcv/ImageBatch.hpp>
18+
#include <nvcv/Tensor.hpp>
19+
#include <optional>
20+
#include "dali/core/dev_buffer.h"
21+
#include "dali/core/static_switch.h"
22+
#include "dali/kernels/common/utils.h"
23+
#include "dali/kernels/dynamic_scratchpad.h"
24+
#include "dali/pipeline/operator/arg_helper.h"
25+
#include "dali/pipeline/operator/checkpointing/stateless_operator.h"
26+
#include "dali/pipeline/operator/operator.h"
27+
28+
#include "dali/operators/nvcvop/nvcvop.h"
29+
#include "dali/operators/image/remap/cvcuda/matrix_adjust.h"
30+
31+
namespace dali {
32+
33+
34+
// Schema of the experimental WarpPerspective operator: one mandatory image input, an optional
// second input with transformation matrices, and the arguments read by the WarpPerspective class.
DALI_SCHEMA(experimental__WarpPerspective)
  .DocStr(R"doc(
Performs a perspective transform on the images.
)doc")
  .NumInput(1, 2)
  .InputDox(0, "input", "TensorList of uint8, uint16, int16 or float",
            "Input data. Must be images in HWC or CHW layout, or a sequence of those.")
  .InputDox(1, "matrix_gpu", "1D TensorList of float",
            "Transformation matrix data. Should be used to pass the GPU data. "
            "For CPU data, the `matrix` argument should be used.")
  .NumOutput(1)
  .InputLayout(0, {"HW", "HWC", "FHWC", "CHW", "FCHW"})
  .AddOptionalArg<float>("size",
                         R"code(Output size, in pixels/points.

The channel dimension should be excluded (for example, for RGB images,
specify ``(480,640)``, not ``(480,640,3)``).
)code",
                         std::vector<float>({}), true)
  .AddOptionalArg<float>("matrix",
                         R"doc(
Perspective transform mapping of destination to source coordinates.
If `inverse_map` argument is set to false, the matrix is interpreted
as a source to destination coordinates mapping.

It is equivalent to OpenCV's ``warpPerspective`` operation with the ``inverse_map`` argument being
analog to the ``WARP_INVERSE_MAP`` flag.

.. note::
  Instead of this argument, the operator can take a second positional input, in which
  case the matrix can be placed on the GPU.)doc",
                         std::vector<float>({}), true, true)
  .AddOptionalArg("border_mode",
                  "Border mode to be used when accessing elements outside input image.\n"
                  "Supported values are: \"constant\", \"replicate\", "
                  "\"reflect\", \"reflect_101\", \"wrap\".",
                  "constant")
  .AddOptionalArg("interp_type", "Type of interpolation used.", DALI_INTERP_LINEAR)
  .AddOptionalArg("pixel_origin", R"doc(Pixel origin. Possible values: "corner", "center".

Determines the meaning of (0, 0) coordinates - "corner" places the origin at the top-left corner of
the top-left pixel (like in OpenGL); "center" places (0, 0) in the center of
the top-left pixel (like in OpenCV).)doc", "corner")
  .AddOptionalArg<float>("fill_value",
                         "Value used to fill areas that are outside the source image when the "
                         "\"constant\" border_mode is chosen.",
                         std::vector<float>({}))
  .AddOptionalArg<bool>("inverse_map",
                        "If set to true (default), the matrix is interpreted as "
                        "destination to source coordinates mapping. "
                        "Otherwise it's interpreted as source to destination "
                        "coordinates mapping.", true);
86+
87+
88+
/**
 * @brief GPU operator that warps images with per-sample 3x3 perspective matrices
 *        by delegating to cvcuda::WarpPerspective.
 *
 * The matrices come either from the `matrix` argument or from the second (GPU) input.
 * When `pixel_origin` is "corner", the matrices are adjusted to the pixel-center convention
 * before being handed over to the CV-CUDA operator.
 */
class WarpPerspective : public nvcvop::NVCVSequenceOperator<StatelessOperator> {
 public:
  explicit WarpPerspective(const OpSpec &spec)
      : nvcvop::NVCVSequenceOperator<StatelessOperator>(spec),
        border_mode_(nvcvop::GetBorderMode(spec.GetArgument<std::string>("border_mode"))),
        interp_type_(nvcvop::GetInterpolationType(spec.GetArgument<DALIInterpType>("interp_type"))),
        inverse_map_(spec.GetArgument<bool>("inverse_map")),
        fill_value_arg_(spec.GetArgument<std::vector<float>>("fill_value")),
        ocv_pixel_(OCVCompatArg(spec.GetArgument<std::string>("pixel_origin"))) {
    // The matrices are later viewed as a single tensor (AsTensor in RunImpl),
    // hence the request for contiguous batch storage.
    matrix_data_.SetContiguity(BatchContiguity::Contiguous);
  }

  // Always true for this operator.
  // NOTE(review): presumably tells the sequence-operator base to unfold the channel dimension
  // of channel-first layouts into separate samples - confirm against NVCVSequenceOperator.
  bool ShouldExpandChannels(int input_idx) const override {
    return true;
  }

  bool CanInferOutputs() const override {
    return true;
  }

  /**
   * @brief Converts the `fill_value` argument into the float4 passed to the CV-CUDA operator.
   *
   * - no value provided  -> all zeros,
   * - a single value     -> that value replicated to all four components,
   * - multiple values    -> one value per channel; the count must equal `channels`.
   *   Components without a corresponding value stay zero; float4 holds at most four
   *   components, so only the first four values are used.
   *
   * @param channels number of channels of the processed data, non-positive when the data
   *                 has no interleaved channel dimension
   * @throws (via DALI_FAIL) when multiple values are given and their count differs
   *         from `channels`, or when `channels` is not positive
   */
  float4 GetFillValue(int channels) const {
    if (fill_value_arg_.size() > 1) {
      if (channels > 0) {
        if (channels == static_cast<int>(fill_value_arg_.size())) {
          // Copy element by element and never more than the vector holds: a blind copy of
          // sizeof(float4) bytes would read past the end of the vector for 2 or 3 channels.
          float components[4] = {0, 0, 0, 0};
          const int num_values = channels < 4 ? channels : 4;
          for (int c = 0; c < num_values; c++) {
            components[c] = fill_value_arg_[c];
          }
          return float4{components[0], components[1], components[2], components[3]};
        } else {
          DALI_FAIL(make_string(
              "Number of values provided as a fill_value should match the number of channels.\n"
              "Number of channels: ",
              channels, ". Number of values provided: ", fill_value_arg_.size(), "."));
        }
      } else {
        DALI_FAIL("Only scalar fill_value can be provided when processing data in planar layout.");
      }
    } else if (fill_value_arg_.size() == 1) {
      auto fv = fill_value_arg_[0];
      float4 fill_value{fv, fv, fv, fv};
      return fill_value;
    } else {
      return float4{0, 0, 0, 0};
    }
  }

  /**
   * @brief Checks the element types of the inputs: images must be uint8, int16, uint16 or float;
   *        the optional matrix input must be float.
   */
  void ValidateTypes(const Workspace &ws) const {
    auto inp_type = ws.Input<GPUBackend>(0).type();
    DALI_ENFORCE(inp_type == DALI_UINT8 || inp_type == DALI_INT16 || inp_type == DALI_UINT16 ||
                     inp_type == DALI_FLOAT,
                 "The operator accepts the following input types: "
                 "uint8, int16, uint16, float.");
    if (ws.NumInput() > 1) {
      auto mat_type = ws.Input<GPUBackend>(1).type();
      DALI_ENFORCE(mat_type == DALI_FLOAT,
                   "Transformation matrix can be provided only as float32 values.");
    }
  }

  /**
   * @brief Parses the `pixel_origin` argument.
   *
   * @return true for "center" (the OpenCV convention, no matrix adjustment needed),
   *         false for "corner"
   * @throws (via DALI_FAIL) for any other string
   */
  bool OCVCompatArg(const std::string &arg) {
    if (arg == "corner") {
      return false;
    } else if (arg == "center") {
      return true;
    } else {
      DALI_FAIL(make_string("Invalid pixel_origin argument: ", arg));
    }
  }

  /**
   * @brief Validates input types, resolves the fill value and computes the output shapes.
   *
   * Without an explicit `size`, the output shape equals the input shape. Otherwise each sample
   * gets the requested (height, width), rounded to the nearest integer and clamped to at least 1,
   * followed by the channel extent when the layout contains a 'C' dimension.
   */
  bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
    ValidateTypes(ws);
    const auto &input = ws.Input<GPUBackend>(0);
    auto input_shape = input.shape();
    auto input_layout = input.GetLayout();
    output_desc.resize(1);

    auto output_shape = input_shape;
    // The channel count is taken from the first sample; -1 denotes "no channel dimension".
    const int channel_dim = input_layout.find('C');
    int channels = (channel_dim != -1) ? input_shape[0][channel_dim] : -1;
    fill_value_ = GetFillValue(channels);
    if (size_arg_.HasExplicitValue()) {
      size_arg_.Acquire(spec_, ws, input_shape.size(), TensorShape<1>(2));
      for (int i = 0; i < input_shape.size(); i++) {
        auto height = std::max<int>(std::roundf(size_arg_[i].data[0]), 1);
        auto width = std::max<int>(std::roundf(size_arg_[i].data[1]), 1);
        auto out_sample_shape = (channels != -1) ? TensorShape<>({height, width, channels}) :
                                                   TensorShape<>({height, width});
        output_shape.set_tensor_shape(i, out_sample_shape);
      }
    }

    output_desc[0] = {output_shape, input.type()};
    return true;
  }

  /**
   * @brief Gathers the matrices, adjusts them when needed and runs the CV-CUDA warp
   *        on the workspace stream.
   */
  void RunImpl(Workspace &ws) override {
    const auto &input = ws.Input<GPUBackend>(0);
    auto &output = ws.Output<GPUBackend>(0);
    output.SetLayout(input.GetLayout());

    kernels::DynamicScratchpad scratchpad({}, AccessOrder(ws.stream()));

    nvcv::Tensor matrix{};
    if (ws.NumInput() > 1) {
      DALI_ENFORCE(!matrix_arg_.HasExplicitValue(),
                   "Matrix input and `matrix` argument should not be provided at the same time.");
      auto &matrix_input = ws.Input<GPUBackend>(1);
      DALI_ENFORCE(matrix_input.shape() ==
                       uniform_list_shape(matrix_input.num_samples(), TensorShape<2>(3, 3)),
                   make_string("Expected a uniform list of 3x3 matrices. "
                               "Instead got data with shape: ",
                               matrix_input.shape()));

      // Work on a private copy: the matrices may be modified in place below
      // and the operator's input must stay intact.
      matrix_data_.Copy(matrix_input, AccessOrder(ws.stream()));
      Tensor<GPUBackend> matrix_tensor = matrix_data_.AsTensor();
      matrix = nvcvop::AsTensor(matrix_tensor, "NW", TensorShape<2>{input.num_samples(), 9});
    } else {
      matrix = AcquireTensorArgument(ws, scratchpad, matrix_arg_, TensorShape<2>{3, 3},
                                     nvcvop::GetDataType<float>(), "W", TensorShape<1>{9});
    }
    if (!ocv_pixel_) {
      // "corner" pixel origin: convert the matrices to the pixel-center convention.
      warp_perspective::adjustMatrices(matrix, ws.stream());
    }

    auto input_images = GetInputBatch(ws, 0);
    auto output_images = GetOutputBatch(ws, 0);
    // (Re)create the CV-CUDA operator when the batch outgrows the size it was built for;
    // the capacity at least doubles each time to limit the number of re-creations.
    if (!warp_perspective_ || input.num_samples() > op_batch_size_) {
      op_batch_size_ = std::max(op_batch_size_ * 2, input.num_samples());
      warp_perspective_.emplace(op_batch_size_);
    }
    int32_t flags = interp_type_;
    if (inverse_map_) {
      flags |= NVCV_WARP_INVERSE_MAP;
    }
    (*warp_perspective_)(ws.stream(), input_images, output_images, matrix, flags, border_mode_,
                         fill_value_);
  }

 private:
  USE_OPERATOR_MEMBERS();
  ArgValue<float, 2> matrix_arg_{"matrix", spec_};
  ArgValue<float, 1> size_arg_{"size", spec_};
  // Batch size the current warp_perspective_ instance was created with.
  int op_batch_size_ = 0;
  NVCVBorderType border_mode_{NVCV_BORDER_CONSTANT};
  NVCVInterpolationType interp_type_{NVCV_INTERP_NEAREST};
  bool inverse_map_{false};
  // Raw `fill_value` argument and the float4 derived from it in SetupImpl.
  std::vector<float> fill_value_arg_{0, 0, 0, 0};
  float4 fill_value_{0, 0, 0, 0};
  // True when pixel_origin is "center", i.e. the matrices need no adjustment.
  bool ocv_pixel_ = true;
  std::optional<cvcuda::WarpPerspective> warp_perspective_{};
  // Private copy of the matrices passed as the second input.
  TensorList<GPUBackend> matrix_data_{};
};
238+
239+
// Registers the WarpPerspective class as the GPU implementation
// of the experimental__WarpPerspective schema declared above.
DALI_REGISTER_OPERATOR(experimental__WarpPerspective, WarpPerspective, GPU);
240+
241+
} // namespace dali

0 commit comments

Comments
 (0)