Skip to content

[MPS] deformable conv2d kernel #9017

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 47 additions & 10 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,7 @@ def test_batched_nms_implementations(self, seed):

class TestDeformConv:
dtype = torch.float64
mps_dtype = torch.float32

def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
stride_h, stride_w = _pair(stride)
Expand Down Expand Up @@ -1050,12 +1051,11 @@ def test_is_leaf_node(self, device):
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_forward(self, device, contiguous, batch_sz, dtype=None):
dtype = dtype or self.dtype
dtype = self.mps_dtype if device == "mps" else dtype or self.dtype
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
in_channels = 6
out_channels = 2
Expand Down Expand Up @@ -1201,13 +1201,50 @@ def test_forward_scriptability(self):
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))


optests.generate_opcheck_tests(
testcase=TestDeformConv,
namespaces=["torchvision"],
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
additional_decorators=[],
test_utils=OPTESTS,
)
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("requires_grad", (True, False))
def test_deform_conv2d_opcheck(dtype, device, requires_grad):
batch_size, channels_in, height, width = 1, 6, 10, 10
kernel_size = (3, 3)
stride = (1, 1)
padding = (1, 1)
dilation = (1, 1)
groups = 2
out_channels = 4
out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad)
offset = torch.randn(batch_size, 2 * kernel_size[0] * kernel_size[1], out_h, out_w,
dtype=dtype, device=device, requires_grad=requires_grad)
weight = torch.randn(out_channels, channels_in // groups, kernel_size[0], kernel_size[1],
dtype=dtype, device=device, requires_grad=requires_grad)
bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad)
use_mask = True
mask = torch.sigmoid(torch.randn(
batch_size,
kernel_size[0] * kernel_size[1],
out_h,
out_w,
dtype=dtype, device=device, requires_grad=requires_grad
))
kwargs = {
"offset": offset,
"weight": weight,
"bias": bias,
"stride_h": stride[0],
"stride_w": stride[1],
"pad_h": padding[0],
"pad_w": padding[1],
"dilation_h": dilation[0],
"dilation_w": dilation[1],
"groups": groups,
"offset_groups": 1,
"use_mask": use_mask,
"mask": mask, # no modulation in this test
}
optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs)



class TestFrozenBNT:
Expand Down
149 changes: 149 additions & 0 deletions torchvision/csrc/ops/mps/deform_conv2d_kernel.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include <ATen/ATen.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

at::Tensor deform_conv2d_forward_kernel(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask) {
using namespace at::native::mps;
at::Tensor input_c = input.contiguous();
at::Tensor weight_c = weight.contiguous();
at::Tensor offset_c = offset.contiguous();
at::Tensor mask_c = mask.contiguous();
at::Tensor bias_c = bias.contiguous();

TORCH_CHECK(input_c.ndimension() == 4, "Input tensor must be 4D");
TORCH_CHECK(weight_c.ndimension() == 4, "Weight tensor must be 4D");
TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D");
TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true");
TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor");
TORCH_CHECK(weight.is_mps(), "weight must be a MPS tensor");
TORCH_CHECK(offset.is_mps(), "offset must be a MPS tensor");
TORCH_CHECK(mask.is_mps(), "mask must be a MPS tensor");
TORCH_CHECK(bias.is_mps(), "bias must be a MPS tensor");

at::DeviceGuard guard(input_c.device());

uint32_t batch = input_c.size(0);
uint32_t in_channels = input_c.size(1);
uint32_t in_h = input_c.size(2);
uint32_t in_w = input_c.size(3);
uint32_t weight_h = weight_c.size(2);
uint32_t weight_w = weight_c.size(3);
uint32_t out_channels = weight_c.size(0);
uint32_t ker_h = dilation_h * (weight_h - 1) + 1;
uint32_t ker_w = dilation_w * (weight_w - 1) + 1;
uint32_t out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
uint32_t out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
uint32_t pad_h_u = static_cast<uint32_t>(pad_h);
uint32_t pad_w_u = static_cast<uint32_t>(pad_w);
uint32_t stride_h_u = static_cast<uint32_t>(stride_h);
uint32_t stride_w_u = static_cast<uint32_t>(stride_w);
uint32_t dilation_h_u = static_cast<uint32_t>(dilation_h);
uint32_t dilation_w_u = static_cast<uint32_t>(dilation_w);

TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels,
"Input channels (", in_channels,
") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")");
TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0,
"Weight tensor's out channels (", weight_c.size(0),
") must be divisible by n_weight_grps (", n_weight_grps, ")");
TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w,
"Offset tensor shape[1] is invalid: got ", offset_c.size(1),
", expected ", n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w,
"Mask tensor shape[1] is invalid: got ", mask_c.size(1),
", expected ", n_offset_grps * weight_h * weight_w);
TORCH_CHECK(in_channels % n_offset_grps == 0,
"Input tensor channels (", in_channels,
") must be divisible by n_offset_grps (", n_offset_grps, ")");
TORCH_CHECK(offset_c.size(0) == batch,
"Offset tensor batch size (", offset_c.size(0),
") must match input tensor batch size (", batch, ")");
TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w,
"Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3),
") must match calculated output dimensions (", out_h, ", ", out_w, ")");
TORCH_CHECK(!use_mask || mask_c.size(0) == batch,
"Mask tensor batch size (", mask_c.size(0),
") must match input tensor batch size (", batch, ")");
TORCH_CHECK(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w),
"Mask tensor spatial dimensions (", mask_c.size(2), ", ", mask_c.size(3),
") must match calculated output dimensions (", out_h, ", ", out_w, ")");
TORCH_CHECK(out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ", out_h, " out_w: ", out_w);

auto columns = at::empty({in_channels * weight_h * weight_w, batch * out_h * out_w}, input_c.options());

id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_c);
id<MTLBuffer> offsetBuffer = getMTLBufferStorage(offset_c);
id<MTLBuffer> maskBuffer = use_mask ? getMTLBufferStorage(mask_c) : nil;
id<MTLBuffer> outputBuffer = getMTLBufferStorage(columns);

id<MTLDevice> device = MPSDevice::getInstance()->device();
std::string kernelName = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> pipelineState = mps::visionPipelineState(device, kernelName);

int num_kernels = in_channels * out_h * out_w * batch;
NSUInteger threadsPerThreadgroup = pipelineState.maxTotalThreadsPerThreadgroup;
NSUInteger threadgroups = (num_kernels + threadsPerThreadgroup - 1) / threadsPerThreadgroup;
MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1);
MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);

MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^{
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
[computeEncoder setComputePipelineState:pipelineState];
at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer,
std::array<uint32_t, 2>{in_h, in_w},
std::array<uint32_t, 2>{weight_h, weight_w},
std::array<uint32_t, 2>{pad_h_u, pad_w_u},
std::array<uint32_t, 2>{stride_h_u, stride_w_u},
std::array<uint32_t, 2>{dilation_h_u, dilation_w_u},
batch, in_channels, n_offset_grps,
std::array<uint32_t, 2>{out_h, out_w},
use_mask, outputBuffer);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
}
});
int in_channels_per_grp = in_channels / n_weight_grps;
int out_channels_per_grp = out_channels / n_weight_grps;
auto weight_grouped = weight_c.view({n_weight_grps, out_channels_per_grp, in_channels_per_grp, weight_h, weight_w});
auto columns_grouped = columns.view({n_weight_grps,
(in_channels * weight_h * weight_w) / n_weight_grps,
batch * out_h * out_w});
auto weight_reshaped = weight_grouped.reshape({n_weight_grps, out_channels_per_grp, -1});
auto out_grouped = at::bmm(weight_reshaped, columns_grouped);
auto out = out_grouped.reshape({n_weight_grps * out_channels_per_grp, batch, out_h, out_w})
.transpose(0, 1);
return out + bias_c.view({1, out_channels, 1, 1});
}

} // namespace

TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
m.impl(
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
TORCH_FN(deform_conv2d_forward_kernel));
}

} // namespace ops
} // namespace vision
Loading