feat: support aten index_put converter for accumulate=False #2880

Merged
merged 7 commits on Jun 19, 2024
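For context, a minimal example (not part of the PR, assuming standard PyTorch semantics) of the operation this converter targets: aten.index_put writes values into the input at positions given by a tuple of index tensors, and with accumulate=False existing entries are simply overwritten.

```python
import torch

x = torch.zeros(5, 5)
indices = (torch.tensor([0, 2]), torch.tensor([2, 0]))  # row indices and column indices
values = torch.tensor([1.0, 3.0])

# Out-of-place index_put with accumulate=False: writes x[0, 2] = 1.0 and x[2, 0] = 3.0
y = torch.ops.aten.index_put.default(x, indices, values, False)
print(y[0, 2], y[2, 0])  # tensor(1.) tensor(3.)
```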
61 changes: 61 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -801,6 +801,67 @@ def aten_ops_select(
)


def index_put_validator(node: Node) -> bool:
if args_bounds_check(node.args, 3, False): # Check if accumulate is valid
_LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
accumulate_valid = False
else:
accumulate_valid = True

# Retrieve input tensor's meta information
input_meta = node.args[0].meta.get("tensor_meta")
if not input_meta:
_LOGGER.warning(
"Meta information of input is missing. Unable to validate if broadcasting is needed, falling back to PyTorch operation."
)
return False

input_shape = input_meta.shape
input_num_dims = len(input_shape)

# Check if broadcasting is valid
indices_num_dims = len(node.args[1])
if indices_num_dims == input_num_dims:
broadcast_valid = True
else:
_LOGGER.debug(
"We do not support broadcasting when the number of index dimensions does not match the number of input tensor dimensions."
)
broadcast_valid = False

# Return validation result
return accumulate_valid and broadcast_valid


@dynamo_tensorrt_converter(
torch.ops.aten.index_put.default,
capability_validator=index_put_validator,
)
@enforce_tensor_types(
{
0: (TRTTensor,),
2: (TRTTensor,),
}
)
def aten_ops_index_put(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.select.index_put_converter(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
args[2],
args_bounds_check(args, 3, False),
)


@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
41 changes: 40 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -6,6 +6,7 @@
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
broadcastable,
@@ -410,7 +411,7 @@ def scatter(
dim = get_positive_dim(dim, len(input_shape))
src_tensor = src
# scatter.value
-    if isinstance(src, int) or isinstance(src, float):
+    if isinstance(src, (int, float)):
src_tensor = get_trt_tensor(
ctx, src * np.ones(index_shape_list), name + "_value_tensor"
)
@@ -446,3 +447,41 @@ def gather(
set_layer_name(gather_layer, target, name + "_gather_layer_element", source_ir)
out = gather_layer.get_output(0)
return out


def index_put_converter(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_tensor: TRTTensor,
indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
values: TRTTensor,
accumulate: bool = False,
) -> TRTTensor:
# Reshape indices to add an extra dimension if necessary (indices is a Tuple of ITensors)
reshaped_indices = []
for i, each_input in enumerate(indices):
Collaborator:
Since indices can be an ITensor per the schema, you may not be able to iterate over it.
In the test case, you can try changing line 173 to inputs=[source_tensor, indices_tensor, value_tensor].
It's kind of similar to the offsets in the annoying embedding_bag. You can think about how to use native TRT layers to do this, like ILoop.

Collaborator:
Besides, what blocks you when accumulate=True?

Collaborator Author:
Thank you very much for your review. When indices is a single torch.Tensor, an error occurs in PyTorch, as shown in the example below. This situation is somewhat different from embedding_bag; it is the case where the input is a tuple of tensors, which we discussed earlier.

If you look at the example, the index_put_ function throws an error when indices is a plain torch.Tensor and only works correctly when indices is a tuple or list.

[screenshot of the PyTorch error omitted]
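A minimal sketch (reconstructed, not the original screenshot; assuming standard PyTorch behavior) of the point being made: index_put_ fails when indices is a single tensor and works when it is a tuple of index tensors.

```python
import torch

x = torch.zeros(5)
idx = torch.tensor([0, 3])
vals = torch.tensor([1.0, 3.0])

# Passing a plain Tensor as `indices` raises an error, per the discussion above.
try:
    x.index_put_(idx, vals)
except Exception as e:
    print("error:", e)

# Wrapping the index tensor in a tuple works (accumulate defaults to False).
x.index_put_((idx,), vals)
print(x)  # tensor([1., 0., 0., 3., 0.])
```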

Therefore, indices can be iterated over with a for loop; I do not iterate over each_input itself, since it is an ITensor. If I am mistaken, your comments would be very helpful.

One more question I have is about the type definition of indices when it is a tuple of tensors. Is it correct to define indices as Union[TRTTensor, Tuple[TRTTensor, ...]]?

Collaborator Author:
When accumulate=True, if there are duplicate index pairs in indices, the corresponding values should be summed and the duplicates removed. Therefore, I aimed to obtain de-duplicated indices with the correspondingly accumulated values, and then feed those into the scatter layer. However, I ran into difficulties implementing the loop that checks for duplicate index pairs in indices.
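For reference, a minimal sketch (not part of this thread, assuming standard aten.index_put semantics) of the accumulate=True behavior described above: values written to duplicate index pairs are summed rather than overwritten.

```python
import torch

x = torch.zeros(5, 5)
indices = (torch.tensor([0, 0]), torch.tensor([1, 1]))  # the pair (0, 1) appears twice
values = torch.tensor([1.0, 2.0])

# accumulate=True sums the values written to the same location: result[0, 1] == 3.0
acc = torch.ops.aten.index_put.default(x, indices, values, True)
print(acc[0, 1])  # tensor(3.)

# accumulate=False overwrites instead; with duplicate indices only one value survives.
over = torch.ops.aten.index_put.default(x, indices, values, False)
print(over[0, 1])
```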

Collaborator (@zewenli98), Jun 10, 2024:
Thanks for the detailed explanations! Yes, you are right: the indices should be a list or tuple, and thus it can be iterated over. Then your current implementation LGTM.

Regarding the type definition of indices when it is a tuple of tensors: I think it could be Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], since a single TRTTensor cannot be iterated over and that matches the schema, right?

Collaborator:
When you say accumulate=True is causing issues, I believe the duplicate indices are what causes them. I faced the same thing in scatter_reduce, and I believe advanced indexing would be the way to deal with it (lengthy code, I believe :( ). Do you have any other ideas?

Collaborator Author:
I have written a validator to handle the accumulate=True case, and I have created a separate issue for implementing the converter for accumulate=True. It would be great to share ideas and work together on it.

if not isinstance(each_input, TRTTensor):
each_input = get_trt_tensor(ctx, each_input, f"{name}_tensor_{i}")
each_input = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_{i}",
each_input,
(-1, 1), # Reshape to (N, 1)
)
reshaped_indices.append(each_input)

# Concatenate along the second dimension (columns)
indices_cat = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat", reshaped_indices, dim=1
)

scatter_layer = ctx.net.add_scatter(
input_tensor, indices_cat, values, trt.ScatterMode.ND
)
scatter_layer.axis = 0
set_layer_name(scatter_layer, target, f"{name}_scatter_layer", source_ir)
return scatter_layer.get_output(0)
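For intuition only (not part of the diff; names are illustrative), a rough PyTorch-level sketch of what the ND scatter above computes: each index tensor becomes a column, the columns are concatenated into an (N, rank) coordinate matrix, and the i-th row of values is written to the coordinate given by the i-th row of that matrix.

```python
import torch

def index_put_reference(input_tensor, indices, values):
    # Mirror the converter: reshape each index tensor to a column and
    # concatenate along dim=1 to get an (N, rank) coordinate matrix.
    coords = torch.stack([idx.reshape(-1) for idx in indices], dim=1)
    out = input_tensor.clone()
    # Scatter ND with accumulate=False: each coordinate row receives one value.
    for row, value in zip(coords, values):
        out[tuple(row.tolist())] = value
    return out

x = torch.zeros(5, 5)
print(index_put_reference(
    x, (torch.tensor([0, 2]), torch.tensor([2, 0])), torch.tensor([1.5, 3.5])
))
```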
192 changes: 192 additions & 0 deletions tests/py/dynamo/conversion/test_index_put_aten.py
@@ -0,0 +1,192 @@
import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestIndexPutConverter(DispatchTestCase):
@parameterized.expand(
[
param(
test_name="1d_indices_single",
source_tensor=torch.zeros([5], dtype=torch.int32),
indices_tensor=(torch.tensor([0], dtype=torch.int32),),
value_tensor=torch.tensor([1], dtype=torch.int32),
),
param(
test_name="1d_indices_multiple",
source_tensor=torch.zeros([5], dtype=torch.int32),
indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
value_tensor=torch.tensor([1, 3], dtype=torch.int32),
),
param(
test_name="2d_indices_single",
source_tensor=torch.zeros([5, 5], dtype=torch.int32),
indices_tensor=(
torch.tensor([2], dtype=torch.int32),
torch.tensor([0], dtype=torch.int32),
),
value_tensor=torch.tensor([3], dtype=torch.int32),
),
param(
test_name="2d_indices_multiple",
source_tensor=torch.zeros([5, 5], dtype=torch.int32),
indices_tensor=(
torch.tensor([0, 2, 2], dtype=torch.int32),
torch.tensor([2, 0, 2], dtype=torch.int32),
),
value_tensor=torch.tensor([1, 3, 4], dtype=torch.int32),
),
param(
test_name="3d_indices_single",
source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
indices_tensor=(
torch.tensor([1], dtype=torch.int32),
torch.tensor([2], dtype=torch.int32),
torch.tensor([2], dtype=torch.int32),
),
value_tensor=torch.tensor([7], dtype=torch.int32),
),
param(
test_name="3d_indices_multiple",
source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
indices_tensor=(
torch.tensor([0, 1, 1], dtype=torch.int32),
torch.tensor([1, 2, 1], dtype=torch.int32),
torch.tensor([2, 0, 2], dtype=torch.int32),
),
value_tensor=torch.tensor([5, 7, 2], dtype=torch.int32),
),
param(
test_name="4d_indices_single",
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
indices_tensor=(
torch.tensor([1], dtype=torch.int32),
torch.tensor([1], dtype=torch.int32),
torch.tensor([0], dtype=torch.int32),
torch.tensor([1], dtype=torch.int32),
),
value_tensor=torch.tensor([5], dtype=torch.int32),
),
param(
test_name="4d_indices_multiple",
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
indices_tensor=(
torch.tensor([0, 1], dtype=torch.int32),
torch.tensor([1, 1], dtype=torch.int32),
torch.tensor([1, 0], dtype=torch.int32),
torch.tensor([1, 0], dtype=torch.int32),
),
value_tensor=torch.tensor([5, 7], dtype=torch.int32),
),
param(
test_name="negative_indices",
source_tensor=torch.zeros([5, 5], dtype=torch.int32),
indices_tensor=(
torch.tensor([-1, -2], dtype=torch.int32),
torch.tensor([2, 0], dtype=torch.int32),
),
value_tensor=torch.tensor([1, 3], dtype=torch.int32),
),
param(
test_name="mixed_indices",
source_tensor=torch.zeros([4, 4], dtype=torch.int32),
indices_tensor=(
torch.tensor([0, 1, -1, -2], dtype=torch.int32),
torch.tensor([0, -1, 2, 1], dtype=torch.int32),
),
value_tensor=torch.tensor([2, 4, 6, 8], dtype=torch.int32),
),
param(
test_name="1d_indices_float",
source_tensor=torch.zeros([5], dtype=torch.float32),
indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
),
param(
test_name="2d_indices_float",
source_tensor=torch.zeros([5, 5], dtype=torch.float32),
indices_tensor=(
torch.tensor([0, 2], dtype=torch.int32),
torch.tensor([2, 0], dtype=torch.int32),
),
value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
),
param(
test_name="3d_indices_float",
source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
indices_tensor=(
torch.tensor([0, 1], dtype=torch.int32),
torch.tensor([1, 2], dtype=torch.int32),
torch.tensor([2, 0], dtype=torch.int32),
),
value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
),
param(
test_name="4d_indices_float",
source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
indices_tensor=(
torch.tensor([0, 1], dtype=torch.int32),
torch.tensor([1, 0], dtype=torch.int32),
torch.tensor([0, 1], dtype=torch.int32),
torch.tensor([1, 0], dtype=torch.int32),
),
value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
),
# param(
# test_name="3d_indices_float_broadcase_index",
# source_tensor=torch.zeros([3, 3, 3], dtype = torch.int32),
# indices_tensor=(
# torch.tensor([0,1], dtype=torch.int32),
# torch.tensor([0,1], dtype=torch.int32),
# ),
# value_tensor=torch.tensor([10], dtype = torch.int32),
# ),
# param(
# test_name="2d_indices_accumulate_True",
# source_tensor=torch.zeros([5, 5], dtype=torch.int32),
# indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
# value_tensor=torch.tensor([1, 2], dtype=torch.int32),
# accumulate=True,
# ),
# param(
# test_name="3d_indices_accumulate_True",
# source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
# indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
# value_tensor=torch.tensor([1, 2], dtype=torch.int32),
# accumulate=True,
# ),
# param(
# test_name="4d_indices_accumulate_True",
# source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
# indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
# value_tensor=torch.tensor([1, 2], dtype=torch.int32),
# accumulate=True,
# ),
]
)
def test_index_put(
self, test_name, source_tensor, indices_tensor, value_tensor, accumulate=False
):
@torch._dynamo.assume_constant_result
def get_indices_tensor():
return indices_tensor

class TestIndexPut(torch.nn.Module):
def forward(self, source_tensor, value_tensor):
indices_tensor_const = get_indices_tensor()
return torch.ops.aten.index_put.default(
source_tensor, indices_tensor_const, value_tensor, accumulate
)

self.run_test(
TestIndexPut(),
inputs=[source_tensor, value_tensor],
enable_passes=True,
use_dynamo_tracer=True,
)


if __name__ == "__main__":
run_tests()