Skip to content

Commit cbdad29

Browse files
authored
feat: support aten index_put converter for accumulate=False (#2880)
1 parent 2eac0bc commit cbdad29

File tree

3 files changed

+293
-1
lines changed

3 files changed

+293
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,67 @@ def aten_ops_select(
807807
)
808808

809809

810+
def index_put_validator(node: Node) -> bool:
    """Decide whether an ``aten.index_put`` node can be lowered by the converter.

    The node is accepted only when all of the following hold:

    * the optional ``accumulate`` argument (position 3) is falsy,
    * the input node carries ``tensor_meta`` so its rank can be inspected,
    * every entry of the index sequence is present (no ``None`` placeholders),
    * the number of index entries equals the rank of the input tensor.

    Args:
        node: FX node for ``torch.ops.aten.index_put.default``; its positional
            arguments are read as ``(input, indices, values[, accumulate])``.

    Returns:
        ``True`` if the converter can handle the node, ``False`` to fall back
        to the PyTorch implementation.
    """
    # accumulate=True is not implemented by the converter.
    if args_bounds_check(node.args, 3, False):
        _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
        return False

    # Retrieve input tensor's meta information
    input_meta = node.args[0].meta.get("tensor_meta")
    if not input_meta:
        _LOGGER.warning(
            "Meta information of input is missing. Unable to validate if broadcasting is needed, falling back to PyTorch operation."
        )
        return False

    indices = node.args[1]

    # aten.index_put takes optional index tensors; a None entry cannot be
    # turned into a TensorRT tensor by the converter, so reject it up front.
    if any(index is None for index in indices):
        _LOGGER.debug(
            "We do not support None entries in the indices of aten.index_put operation"
        )
        return False

    # One index tensor per input dimension is required.
    if len(indices) != len(input_meta.shape):
        _LOGGER.debug(
            "We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions."
        )
        return False

    return True
840+
841+
842+
@dynamo_tensorrt_converter(
    torch.ops.aten.index_put.default,
    capability_validator=index_put_validator,
)
@enforce_tensor_types({0: (TRTTensor,), 2: (TRTTensor,)})
def aten_ops_index_put(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter entry point for ``torch.ops.aten.index_put.default``.

    Unpacks the positional arguments ``(input, indices, values[, accumulate])``
    and forwards them to ``impl.select.index_put_converter`` with
    ``SourceIR.ATEN`` as the source IR. ``accumulate`` falls back to ``False``
    when it is not supplied. ``kwargs`` is not used.
    """
    input_tensor = args[0]
    indices = args[1]
    values = args[2]
    accumulate = args_bounds_check(args, 3, False)

    return impl.select.index_put_converter(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input_tensor,
        indices,
        values,
        accumulate,
    )
869+
870+
810871
@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
811872
@enforce_tensor_types(
812873
{

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from torch.fx.node import Target
88
from torch_tensorrt.dynamo._SourceIR import SourceIR
9+
from torch_tensorrt.dynamo.conversion import impl
910
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1011
from torch_tensorrt.dynamo.conversion.converter_utils import (
1112
broadcastable,
@@ -410,7 +411,7 @@ def scatter(
410411
dim = get_positive_dim(dim, len(input_shape))
411412
src_tensor = src
412413
# scatter.value
413-
if isinstance(src, int) or isinstance(src, float):
414+
if isinstance(src, (int, float)):
414415
src_tensor = get_trt_tensor(
415416
ctx, src * np.ones(index_shape_list), name + "_value_tensor"
416417
)
@@ -446,3 +447,41 @@ def gather(
446447
set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
447448
out = gather_layer.get_output(0)
448449
return out
450+
451+
452+
def index_put_converter(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_tensor: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
    values: TRTTensor,
    accumulate: bool = False,
) -> TRTTensor:
    """Lower ``index_put`` to a TensorRT scatter layer in ``ScatterMode.ND``.

    Each index tensor is converted to a TensorRT tensor if needed, reshaped to
    a single column of shape ``(-1, 1)``, and the columns are concatenated
    along dim 1 to form the scatter indices.

    Args:
        ctx: Conversion context that owns the network being built.
        target: FX target, used for layer naming.
        source_ir: IR the op originated from, used for layer naming.
        name: Base name for the layers created here.
        input_tensor: Tensor that receives the scattered values.
        indices: One index tensor per entry; non-TensorRT entries are
            converted with ``get_trt_tensor``.
        values: Values passed to the scatter layer as updates.
        accumulate: Accepted for signature parity; not read in this function.

    Returns:
        The output tensor of the scatter layer.
    """

    def _to_column(position: int, index: Any) -> TRTTensor:
        # Materialise constants as TensorRT tensors before reshaping.
        if not isinstance(index, TRTTensor):
            index = get_trt_tensor(ctx, index, f"{name}_tensor_{position}")
        # Reshape to (N, 1)
        return impl.shuffle.reshape(
            ctx,
            target,
            source_ir,
            f"{name}_reshape_{position}",
            index,
            (-1, 1),
        )

    index_columns = [
        _to_column(position, index) for position, index in enumerate(indices)
    ]

    # Concatenate along the second dimension (columns)
    stacked_indices = impl.cat.cat(
        ctx, target, source_ir, f"{name}_cat", index_columns, dim=1
    )

    layer = ctx.net.add_scatter(
        input_tensor, stacked_indices, values, trt.ScatterMode.ND
    )
    layer.axis = 0
    set_layer_name(layer, target, f"{name}_scatter_layer", source_ir)
    return layer.get_output(0)
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import torch
2+
from parameterized import param, parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
5+
from .harness import DispatchTestCase
6+
7+
8+
def _int32(data):
    """Build an int32 tensor from a Python list."""
    return torch.tensor(data, dtype=torch.int32)


def _float32(data):
    """Build a float32 tensor from a Python list."""
    return torch.tensor(data, dtype=torch.float32)


class TestIndexPutConverter(DispatchTestCase):
    """Dispatch tests for the ``aten.index_put`` converter with accumulate=False."""

    @parameterized.expand(
        [
            # --- int32 source / values, one index tensor per dimension ---
            param(
                test_name="1d_indices_single",
                source_tensor=torch.zeros((5,), dtype=torch.int32),
                indices_tensor=(_int32([0]),),
                value_tensor=_int32([1]),
            ),
            param(
                test_name="1d_indices_multiple",
                source_tensor=torch.zeros((5,), dtype=torch.int32),
                indices_tensor=(_int32([0, 3]),),
                value_tensor=_int32([1, 3]),
            ),
            param(
                test_name="2d_indices_single",
                source_tensor=torch.zeros((5, 5), dtype=torch.int32),
                indices_tensor=(_int32([2]), _int32([0])),
                value_tensor=_int32([3]),
            ),
            param(
                test_name="2d_indices_multiple",
                source_tensor=torch.zeros((5, 5), dtype=torch.int32),
                indices_tensor=(_int32([0, 2, 2]), _int32([2, 0, 2])),
                value_tensor=_int32([1, 3, 4]),
            ),
            param(
                test_name="3d_indices_single",
                source_tensor=torch.zeros((3, 3, 3), dtype=torch.int32),
                indices_tensor=(_int32([1]), _int32([2]), _int32([2])),
                value_tensor=_int32([7]),
            ),
            param(
                test_name="3d_indices_multiple",
                source_tensor=torch.zeros((3, 3, 3), dtype=torch.int32),
                indices_tensor=(
                    _int32([0, 1, 1]),
                    _int32([1, 2, 1]),
                    _int32([2, 0, 2]),
                ),
                value_tensor=_int32([5, 7, 2]),
            ),
            param(
                test_name="4d_indices_single",
                source_tensor=torch.zeros((2, 2, 2, 2), dtype=torch.int32),
                indices_tensor=(
                    _int32([1]),
                    _int32([1]),
                    _int32([0]),
                    _int32([1]),
                ),
                value_tensor=_int32([5]),
            ),
            param(
                test_name="4d_indices_multiple",
                source_tensor=torch.zeros((2, 2, 2, 2), dtype=torch.int32),
                indices_tensor=(
                    _int32([0, 1]),
                    _int32([1, 1]),
                    _int32([1, 0]),
                    _int32([1, 0]),
                ),
                value_tensor=_int32([5, 7]),
            ),
            # --- negative / mixed-sign indices ---
            param(
                test_name="negative_indices",
                source_tensor=torch.zeros((5, 5), dtype=torch.int32),
                indices_tensor=(_int32([-1, -2]), _int32([2, 0])),
                value_tensor=_int32([1, 3]),
            ),
            param(
                test_name="mixed_indices",
                source_tensor=torch.zeros((4, 4), dtype=torch.int32),
                indices_tensor=(_int32([0, 1, -1, -2]), _int32([0, -1, 2, 1])),
                value_tensor=_int32([2, 4, 6, 8]),
            ),
            # --- float32 source / values ---
            param(
                test_name="1d_indices_float",
                source_tensor=torch.zeros((5,), dtype=torch.float32),
                indices_tensor=(_int32([0, 3]),),
                value_tensor=_float32([1.5, 3.5]),
            ),
            param(
                test_name="2d_indices_float",
                source_tensor=torch.zeros((5, 5), dtype=torch.float32),
                indices_tensor=(_int32([0, 2]), _int32([2, 0])),
                value_tensor=_float32([1.5, 3.5]),
            ),
            param(
                test_name="3d_indices_float",
                source_tensor=torch.zeros((3, 3, 3), dtype=torch.float32),
                indices_tensor=(
                    _int32([0, 1]),
                    _int32([1, 2]),
                    _int32([2, 0]),
                ),
                value_tensor=_float32([5.5, 7.5]),
            ),
            param(
                test_name="4d_indices_float",
                source_tensor=torch.zeros((2, 2, 2, 2), dtype=torch.float32),
                indices_tensor=(
                    _int32([0, 1]),
                    _int32([1, 0]),
                    _int32([0, 1]),
                    _int32([1, 0]),
                ),
                value_tensor=_float32([5.5, 7.5]),
            ),
            # Disabled cases, kept for reference: they use fewer index tensors
            # than input dimensions (broadcast) or accumulate=True.
            # param(
            #     test_name="3d_indices_float_broadcast_index",
            #     source_tensor=torch.zeros([3, 3, 3], dtype = torch.int32),
            #     indices_tensor=(
            #         torch.tensor([0,1], dtype=torch.int32),
            #         torch.tensor([0,1], dtype=torch.int32),
            #     ),
            #     value_tensor=torch.tensor([10], dtype = torch.int32),
            # ),
            # param(
            #     test_name="2d_indices_accumulate_True",
            #     source_tensor=torch.zeros([5, 5], dtype=torch.int32),
            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
            #     accumulate=True,
            # ),
            # param(
            #     test_name="3d_indices_accumulate_True",
            #     source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
            #     accumulate=True,
            # ),
            # param(
            #     test_name="4d_indices_accumulate_True",
            #     source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
            #     accumulate=True,
            # ),
        ]
    )
    def test_index_put(
        self, test_name, source_tensor, indices_tensor, value_tensor, accumulate=False
    ):
        """Run ``aten.index_put`` through the dispatch harness.

        The index tensors are returned from a function wrapped in
        ``torch._dynamo.assume_constant_result``; only the source and value
        tensors are passed as runtime inputs.
        """

        @torch._dynamo.assume_constant_result
        def constant_indices():
            return indices_tensor

        class IndexPutModule(torch.nn.Module):
            def forward(self, source, values):
                frozen_indices = constant_indices()
                return torch.ops.aten.index_put.default(
                    source, frozen_indices, values, accumulate
                )

        self.run_test(
            IndexPutModule(),
            inputs=[source_tensor, value_tensor],
            enable_passes=True,
            use_dynamo_tracer=True,
        )
189+
190+
191+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)