Skip to content

Commit 193abf8

Browse files
committed
gather converter
1 parent 643fb7b commit 193abf8

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+19
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,25 @@ def aten_ops_clamp(
751751
)
752752

753753

754+
@dynamo_tensorrt_converter(torch.ops.aten.gather.default)
755+
@enforce_tensor_types(
756+
{
757+
0: (TRTTensor,),
758+
2: (TRTTensor,),
759+
}
760+
)
761+
def aten_ops_gather(
762+
ctx: ConversionContext,
763+
target: Target,
764+
args: Tuple[Argument, ...],
765+
kwargs: Dict[str, Argument],
766+
name: str,
767+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
768+
return impl.select.gather(
769+
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
770+
)
771+
772+
754773
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
755774
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
756775
@enforce_tensor_types(

py/torch_tensorrt/dynamo/conversion/impl/select.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def index_select(
387387
dim = get_positive_dim(dim, len(input.shape))
388388
gather_layer = ctx.net.add_gather(input, index, axis=dim)
389389

390-
set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
390+
set_layer_name(gather_layer, target, f"{name}_gather_layer_default", source_ir)
391391

392392
return gather_layer.get_output(0)
393393

@@ -428,3 +428,21 @@ def scatter(
428428
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
429429
out = scatter_layer.get_output(0)
430430
return out
431+
432+
433+
def gather(
434+
ctx: ConversionContext,
435+
target: Target,
436+
source_ir: Optional[SourceIR],
437+
name: str,
438+
input: TRTTensor,
439+
dim: int,
440+
index: Union[TRTTensor, np.ndarray, torch.Tensor],
441+
) -> TRTTensor:
442+
input_shape = input.shape
443+
dim = get_positive_dim(dim, len(input_shape))
444+
gather_layer = ctx.net.add_gather(input, index, axis=dim)
445+
gather_layer.mode = trt.GatherMode.ELEMENT
446+
set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
447+
out = gather_layer.get_output(0)
448+
return out
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import torch
2+
from parameterized import parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestGatherValueConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
(
13+
"gather_zero_dim_indexOne_constant_value",
14+
0,
15+
torch.tensor([[0, 1, 2, 0]]),
16+
),
17+
(
18+
"gather_zero_dim_indexTwo_constant_value",
19+
0,
20+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
21+
),
22+
(
23+
"gather_one_dim_indexOne_constant_value",
24+
1,
25+
torch.tensor([[0, 1, 2, 0]]),
26+
),
27+
(
28+
"gather_one_dim_indexTwo_costant_value",
29+
1,
30+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
31+
),
32+
]
33+
)
34+
def test_gather_index_constant(self, _, dim, index):
35+
class TestModule(torch.nn.Module):
36+
def __init__(self):
37+
super().__init__()
38+
39+
def forward(self, input):
40+
return torch.ops.aten.gather.default(input, dim, index)
41+
42+
input = torch.zeros(3, 5, dtype=torch.int32)
43+
inputs = [input]
44+
self.run_test(TestModule(), inputs)
45+
46+
@parameterized.expand(
47+
[
48+
("gather_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]])),
49+
(
50+
"gather_zero_dim_indexTwo_value",
51+
0,
52+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
53+
),
54+
("gather_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]])),
55+
(
56+
"gather_one_dim_indexTwo_value",
57+
1,
58+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
59+
),
60+
]
61+
)
62+
def test_gather_index_input(self, _, dim, index):
63+
class TestModule(torch.nn.Module):
64+
def __init__(self):
65+
super().__init__()
66+
67+
def forward(self, input, index):
68+
return torch.ops.aten.gather.default(input, dim, index)
69+
70+
input = torch.zeros(3, 5, dtype=torch.int32)
71+
inputs = [input, index]
72+
self.run_test(TestModule(), inputs)

0 commit comments

Comments
 (0)