
Commit 904dad4

More rotated bboxes transforms (#9095)

1 parent: 428a54c

4 files changed: 263 additions, 74 deletions

test/common_utils.py

Lines changed: 15 additions & 8 deletions
```diff
@@ -417,6 +417,13 @@ def sample_position(values, max_value):
         format = tv_tensors.BoundingBoxFormat[format]
 
     dtype = dtype or torch.float32
+    int_dtype = dtype in (
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+    )
 
     h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
     y = sample_position(h, canvas_size[0])
```
```diff
@@ -443,17 +450,17 @@ def sample_position(values, max_value):
     elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
         r_rad = r * torch.pi / 180.0
         cos, sin = torch.cos(r_rad), torch.sin(r_rad)
-        x1, y1 = x, y
-        x2 = x1 + w * cos
-        y2 = y1 - w * sin
-        x3 = x2 + h * sin
-        y3 = y2 + h * cos
-        x4 = x1 + h * sin
-        y4 = y1 + h * cos
+        x1 = torch.round(x) if int_dtype else x
+        y1 = torch.round(y) if int_dtype else y
+        x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
+        y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
+        x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
+        y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
+        x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
+        y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
         parts = (x1, y1, x2, y2, x3, y3, x4, y4)
     else:
         raise ValueError(f"Format {format} is not supported")
-
     return tv_tensors.BoundingBoxes(
         torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
    )
```
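The rationale for the new `int_dtype` branch: each XYXYXYXY corner is derived from the previous one, so for integer dtypes the coordinates are rounded as they are produced rather than truncated once at the final cast, which would skew the later corners. A minimal standalone sketch of the same construction (`rotated_box_corners` is an illustrative name, not a helper in `common_utils.py`):

```python
import torch

# Illustrative sketch only: mirrors the corner construction above.
def rotated_box_corners(x, y, w, h, r, round_to_int=False):
    r_rad = r * torch.pi / 180.0
    cos, sin = torch.cos(r_rad), torch.sin(r_rad)
    maybe_round = torch.round if round_to_int else (lambda t: t)
    # Walk from (x1, y1) along the width, then the height, of the box,
    # rounding each intermediate corner when targeting an integer dtype.
    x1, y1 = maybe_round(x), maybe_round(y)
    x2, y2 = maybe_round(x1 + w * cos), maybe_round(y1 - w * sin)
    x3, y3 = maybe_round(x2 + h * sin), maybe_round(y2 + h * cos)
    x4, y4 = maybe_round(x1 + h * sin), maybe_round(y1 + h * cos)
    return torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1)

t = lambda v: torch.tensor(v)
corners = rotated_box_corners(t(10.0), t(10.0), t(4.0), t(2.0), t(30.0), round_to_int=True)
# tensor([10., 10., 13.,  8., 14., 10., 11., 12.])
```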

test/test_transforms_v2.py

Lines changed: 114 additions & 38 deletions
```diff
@@ -49,7 +49,7 @@
 from torchvision.transforms.functional import pil_modes_mapping, to_pil_image
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2._utils import check_type, is_pure_tensor
-from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs
+from torchvision.transforms.v2.functional._geometry import _get_perspective_coeffs, _parallelogram_to_bounding_boxes
 from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal
 
 
```
```diff
@@ -560,7 +560,9 @@ def affine_bounding_boxes(bounding_boxes):
     )
 
 
-def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
+def reference_affine_rotated_bounding_boxes_helper(
+    bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True, flip=False
+):
     format = bounding_boxes.format
     canvas_size = new_canvas_size or bounding_boxes.canvas_size
 
```
```diff
@@ -588,21 +590,34 @@ def affine_rotated_bounding_boxes(bounding_boxes):
         transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)
         output = torch.tensor(
             [
-                float(transformed_points[1, 0]),
-                float(transformed_points[1, 1]),
                 float(transformed_points[0, 0]),
                 float(transformed_points[0, 1]),
-                float(transformed_points[3, 0]),
-                float(transformed_points[3, 1]),
+                float(transformed_points[1, 0]),
+                float(transformed_points[1, 1]),
                 float(transformed_points[2, 0]),
                 float(transformed_points[2, 1]),
+                float(transformed_points[3, 0]),
+                float(transformed_points[3, 1]),
             ]
         )
 
+        output = output[[2, 3, 0, 1, 6, 7, 4, 5]] if flip else output
+        output = _parallelogram_to_bounding_boxes(output)
+
         output = F.convert_bounding_box_format(
             output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
         )
 
+        if torch.is_floating_point(output) and dtype in (
+            torch.uint8,
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.int64,
+        ):
+            # it is better to round before cast
+            output = torch.round(output)
+
         if clamp:
             # It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
             output = F.clamp_bounding_boxes(
```
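For rotated boxes, a horizontal or vertical flip reverses the corner winding, so the helper re-indexes the flat `(x1, y1, ..., x4, y4)` layout with `[2, 3, 0, 1, 6, 7, 4, 5]`, swapping corners 1↔2 and 3↔4, before squaring the result back up via `_parallelogram_to_bounding_boxes`. A quick illustration of just that re-indexing (values are made up):

```python
import torch

# Flat XYXYXYXY layout: (x1, y1, x2, y2, x3, y3, x4, y4)
box = torch.tensor([0.0, 0.0, 4.0, 0.0, 4.0, 2.0, 0.0, 2.0])
flipped = box[[2, 3, 0, 1, 6, 7, 4, 5]]
# Corner 1 <-> corner 2 and corner 3 <-> corner 4:
# tensor([4., 0., 0., 0., 0., 2., 4., 2.])
```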
```diff
@@ -707,7 +722,7 @@ def test_kernel_image(self, size, interpolation, use_max_size, antialias, dtype,
             check_scripted_vs_eager=not isinstance(size, int),
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("use_max_size", [True, False])
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
```
```diff
@@ -725,6 +740,7 @@ def test_kernel_bounding_boxes(self, format, size, use_max_size, dtype, device):
         check_kernel(
             F.resize_bounding_boxes,
             bounding_boxes,
+            format=format,
             canvas_size=bounding_boxes.canvas_size,
             size=size,
             **max_size_kwarg,
```
```diff
@@ -816,7 +832,7 @@ def test_image_correctness(self, size, interpolation, use_max_size, fn):
         self._check_output_size(image, actual, size=size, **max_size_kwarg)
         torch.testing.assert_close(actual, expected, atol=1, rtol=0)
 
-    def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=None):
+    def _reference_resize_bounding_boxes(self, bounding_boxes, format, *, size, max_size=None):
         old_height, old_width = bounding_boxes.canvas_size
         new_height, new_width = self._compute_output_size(
             input_size=bounding_boxes.canvas_size, size=size, max_size=max_size
```
```diff
@@ -832,13 +848,19 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
             ],
         )
 
-        return reference_affine_bounding_boxes_helper(
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
+        )
+
+        return helper(
             bounding_boxes,
             affine_matrix=affine_matrix,
             new_canvas_size=(new_height, new_width),
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("use_max_size", [True, False])
     @pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
```
```diff
@@ -849,7 +871,7 @@ def test_bounding_boxes_correctness(self, format, size, use_max_size, fn):
         bounding_boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE)
 
         actual = fn(bounding_boxes, size=size, **max_size_kwarg)
-        expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg)
+        expected = self._reference_resize_bounding_boxes(bounding_boxes, format=format, size=size, **max_size_kwarg)
 
         self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg)
         torch.testing.assert_close(actual, expected)
```
```diff
@@ -1152,7 +1174,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.B
         )
 
         helper = (
-            reference_affine_rotated_bounding_boxes_helper
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
```
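The same dispatch pattern recurs throughout this file: pick the rotated or axis-aligned reference helper based on the box format, pre-binding extra keyword arguments with `functools.partial` where needed. A condensed sketch of the pattern with stand-in helpers (names here are illustrative):

```python
import functools

def plain_helper(boxes, *, affine_matrix):  # stand-in for the axis-aligned helper
    return ("plain", affine_matrix)

def rotated_helper(boxes, *, affine_matrix, flip=False):  # stand-in for the rotated helper
    return ("rotated", affine_matrix, flip)

def reference(boxes, affine_matrix, *, is_rotated):
    helper = (
        functools.partial(rotated_helper, flip=True)  # flip=True is pre-bound
        if is_rotated
        else plain_helper
    )
    return helper(boxes, affine_matrix=affine_matrix)

print(reference([], "M", is_rotated=True))   # ('rotated', 'M', True)
print(reference([], "M", is_rotated=False))  # ('plain', 'M')
```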
```diff
@@ -1257,7 +1279,7 @@ def test_kernel_image(self, param, value, dtype, device):
         shear=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"],
         center=_EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"],
     )
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_boxes(self, param, value, format, dtype, device):
```
```diff
@@ -1399,14 +1421,22 @@ def _reference_affine_bounding_boxes(self, bounding_boxes, *, angle, translate,
         if center is None:
             center = [s * 0.5 for s in bounding_boxes.canvas_size[::-1]]
 
-        return reference_affine_bounding_boxes_helper(
+        affine_matrix = self._compute_affine_matrix(
+            angle=angle, translate=translate, scale=scale, shear=shear, center=center
+        )
+
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
+        )
+
+        return helper(
             bounding_boxes,
-            affine_matrix=self._compute_affine_matrix(
-                angle=angle, translate=translate, scale=scale, shear=shear, center=center
-            ),
+            affine_matrix=affine_matrix,
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
     @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"])
     @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
```
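For reference, the `_compute_affine_matrix` call factored out above composes the usual center/rotate-scale/shear/translate chain. A sketch of how such a 2x3 matrix is typically assembled (textbook composition; the exact helper in the test file may differ in detail):

```python
import math
import numpy as np

def compute_affine_matrix(angle, translate, scale, shear, center):
    # T(translate) @ T(center) @ R(angle)*S(scale) @ Shear @ T(-center), as 3x3
    rot = math.radians(angle)
    sx, sy = (math.radians(s) for s in shear)
    cx, cy = center
    tx, ty = translate

    c = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
    t = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
    rs = np.array(
        [
            [scale * math.cos(rot), -scale * math.sin(rot), 0],
            [scale * math.sin(rot), scale * math.cos(rot), 0],
            [0, 0, 1],
        ]
    )
    shear_x = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
    shear_y = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
    matrix = t @ c @ rs @ shear_y @ shear_x @ np.linalg.inv(c)
    return matrix[:2, :]  # boxes only need the top two rows

m = compute_affine_matrix(angle=30, translate=(0, 0), scale=1.0, shear=(0, 0), center=(8, 8))
```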
```diff
@@ -1607,7 +1637,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.Bou
         )
 
         helper = (
-            reference_affine_rotated_bounding_boxes_helper
+            functools.partial(reference_affine_rotated_bounding_boxes_helper, flip=True)
             if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
             else reference_affine_bounding_boxes_helper
         )
```
```diff
@@ -2914,7 +2944,7 @@ def test_kernel_image(self, kwargs, dtype, device):
         check_kernel(F.crop_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), **kwargs)
 
     @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_bounding_box(self, kwargs, format, dtype, device):
```
```diff
@@ -3059,12 +3089,15 @@ def _reference_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, w
                 [0, 1, -top],
             ],
         )
-        return reference_affine_bounding_boxes_helper(
-            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width)
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
         )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width))
 
     @pytest.mark.parametrize("kwargs", CORRECTNESS_CROP_KWARGS)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device):
```
```diff
@@ -3077,7 +3110,7 @@ def test_functional_bounding_box_correctness(self, kwargs, format, dtype, device
         assert_equal(F.get_size(actual), F.get_size(expected))
 
     @pytest.mark.parametrize("output_size", [(17, 11), (11, 17), (11, 11)])
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("seed", list(range(5)))
```
```diff
@@ -3099,7 +3132,7 @@ def test_transform_bounding_boxes_correctness(self, output_size, format, dtype,
 
         expected = self._reference_crop_bounding_boxes(bounding_boxes, **params)
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
         assert_equal(F.get_size(actual), F.get_size(expected))
 
     def test_errors(self):
```
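Several exact `assert_equal` checks become `torch.testing.assert_close` in this commit: rotated-box references go through floating-point trigonometry, so results can differ from the kernels by rounding noise. `assert_close` compares within dtype-dependent default tolerances instead of bit-for-bit:

```python
import torch

a = torch.tensor([1.0000001])
b = torch.tensor([1.0])

torch.testing.assert_close(a, b)  # passes: within default float32 rtol/atol
# torch.testing.assert_close(a, b, rtol=0, atol=0)  # exact mode would raise
```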
```diff
@@ -3834,13 +3867,19 @@ def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, h
         )
         affine_matrix = (resize_affine_matrix @ crop_affine_matrix)[:2, :]
 
-        return reference_affine_bounding_boxes_helper(
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
+        )
+
+        return helper(
             bounding_boxes,
             affine_matrix=affine_matrix,
             new_canvas_size=size,
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_functional_bounding_boxes_correctness(self, format):
         bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
 
```
```diff
@@ -3849,7 +3888,7 @@ def test_functional_bounding_boxes_correctness(self, format):
             bounding_boxes, **self.CROP_KWARGS, size=self.OUTPUT_SIZE
         )
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
         assert_equal(F.get_size(actual), F.get_size(expected))
 
     def test_transform_errors_warnings(self):
```
```diff
@@ -3914,7 +3953,7 @@ def test_kernel_image(self, param, value, dtype, device):
             ),
         )
 
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_kernel_bounding_boxes(self, format):
         bounding_boxes = make_bounding_boxes(format=format)
         check_kernel(
```
```diff
@@ -4034,12 +4073,15 @@ def _reference_pad_bounding_boxes(self, bounding_boxes, *, padding):
         height = bounding_boxes.canvas_size[0] + top + bottom
         width = bounding_boxes.canvas_size[1] + left + right
 
-        return reference_affine_bounding_boxes_helper(
-            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width)
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
         )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=(height, width))
 
     @pytest.mark.parametrize("padding", CORRECTNESS_PADDINGS)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("fn", [F.pad, transform_cls_to_functional(transforms.Pad)])
```
```diff
@@ -4049,7 +4091,7 @@ def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
         actual = fn(bounding_boxes, padding=padding)
         expected = self._reference_pad_bounding_boxes(bounding_boxes, padding=padding)
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
 
 
 class TestCenterCrop:
```
```diff
@@ -4068,7 +4110,7 @@ def test_kernel_image(self, output_size, dtype, device):
         )
 
     @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     def test_kernel_bounding_boxes(self, output_size, format):
         bounding_boxes = make_bounding_boxes(self.INPUT_SIZE, format=format)
         check_kernel(
```
```diff
@@ -4142,12 +4184,15 @@ def _reference_center_crop_bounding_boxes(self, bounding_boxes, output_size):
                 [0, 1, -top],
             ],
         )
-        return reference_affine_bounding_boxes_helper(
-            bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=output_size
+        helper = (
+            reference_affine_rotated_bounding_boxes_helper
+            if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
+            else reference_affine_bounding_boxes_helper
         )
+        return helper(bounding_boxes, affine_matrix=affine_matrix, new_canvas_size=output_size)
 
     @pytest.mark.parametrize("output_size", OUTPUT_SIZES)
-    @pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
+    @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
     @pytest.mark.parametrize("dtype", [torch.int64, torch.float32])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("fn", [F.center_crop, transform_cls_to_functional(transforms.CenterCrop)])
```
```diff
@@ -4157,7 +4202,7 @@ def test_bounding_boxes_correctness(self, output_size, format, dtype, device, fn
         actual = fn(bounding_boxes, output_size)
         expected = self._reference_center_crop_bounding_boxes(bounding_boxes, output_size)
 
-        assert_equal(actual, expected)
+        torch.testing.assert_close(actual, expected)
 
 
 class TestPerspective:
```
```diff
@@ -5894,6 +5939,37 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t
     assert out_label == label
 
 
+@pytest.mark.parametrize("input_size", [(17, 11), (11, 17), (11, 11)])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
+@pytest.mark.parametrize("device", cpu_and_cuda())
+def test_parallelogram_to_bounding_boxes(input_size, dtype, device):
+    # Assert that applying `_parallelogram_to_bounding_boxes` to rotated boxes
+    # does not modify the input.
+    bounding_boxes = make_bounding_boxes(
+        input_size, format=tv_tensors.BoundingBoxFormat.XYXYXYXY, dtype=dtype, device=device
+    )
+    actual = _parallelogram_to_bounding_boxes(bounding_boxes)
+    torch.testing.assert_close(actual, bounding_boxes, rtol=0, atol=1)
+
+    # Test the transformation of two simple parallelograms.
+    #  1---2          1----2
+    # /   /     ->    |    |
+    # 4---3           4----3
+
+    # 1---2           1----2
+    #  \   \    ->    |    |
+    #   4---3         4----3
+    parallelogram = torch.tensor([[1, 0, 4, 0, 3, 2, 0, 2], [0, 0, 3, 0, 4, 2, 1, 2]])
+    expected = torch.tensor(
+        [
+            [0, 0, 4, 0, 4, 2, 0, 2],
+            [0, 0, 4, 0, 4, 2, 0, 2],
+        ]
+    )
+    actual = _parallelogram_to_bounding_boxes(parallelogram)
+    assert_equal(actual, expected)
+
+
 @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, tv_tensors.Image))
 @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
 @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage))
```
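The expected values in `test_parallelogram_to_bounding_boxes` are consistent with squaring a parallelogram up to the smallest rectangle that is aligned with its 1→2 edge and encloses all four corners. A hedged sketch of that geometric reading (`enclosing_rotated_rect` is illustrative only, not the actual `_parallelogram_to_bounding_boxes` implementation):

```python
import torch

def enclosing_rotated_rect(p):
    # p: (..., 8) flat corners (x1, y1, ..., x4, y4) of a parallelogram.
    # Build the smallest rectangle aligned with the 1->2 edge that encloses
    # all four corners; assumes a non-degenerate 1->2 edge.
    pts = p.reshape(-1, 4, 2).to(torch.float32)
    u = pts[:, 1] - pts[:, 0]
    u = u / u.norm(dim=-1, keepdim=True)          # unit vector along edge 1->2
    v = torch.stack([-u[:, 1], u[:, 0]], dim=-1)  # perpendicular unit vector
    rel = pts - pts[:, :1]                        # corners relative to corner 1
    a = (rel * u[:, None]).sum(-1)                # coordinates along u
    b = (rel * v[:, None]).sum(-1)                # coordinates along v
    a0, a1 = a.min(-1).values, a.max(-1).values
    b0, b1 = b.min(-1).values, b.max(-1).values
    corners = torch.stack([
        pts[:, 0] + a0[:, None] * u + b0[:, None] * v,
        pts[:, 0] + a1[:, None] * u + b0[:, None] * v,
        pts[:, 0] + a1[:, None] * u + b1[:, None] * v,
        pts[:, 0] + a0[:, None] * u + b1[:, None] * v,
    ], dim=1)
    return corners.reshape(p.shape)

parallelogram = torch.tensor([[1.0, 0.0, 4.0, 0.0, 3.0, 2.0, 0.0, 2.0]])
print(enclosing_rotated_rect(parallelogram))
# tensor([[0., 0., 4., 0., 4., 2., 0., 2.]])  (matches the expected value above)
```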
