feat: support aten index_put converter for accumulate=False #2880
Merged
Commits (7)
- cd790f0  1d index_put converter with tuple of Tensor index (chohk88)
- ab757e6  feat: support aten.index_put converter except accumulate True (chohk88)
- 9953051  feat: Add validator to always return False due to 'use_dynamo_tracer=…' (chohk88)
- c72222a  chore: minor linting issue (chohk88)
- 5b3b83d  feat: add assume_constant_result decorator to handle tuple of tensors (chohk88)
- 193349a  chore: minor linting (chohk88)
- 4267d1f  chore: revise type hinting and delete unused package (chohk88)
New test file added by this PR:

import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestIndexPutConverter(DispatchTestCase):
    @parameterized.expand(
        [
            param(
                test_name="1d_indices_single",
                source_tensor=torch.zeros([5], dtype=torch.int32),
                indices_tensor=(torch.tensor([0], dtype=torch.int32),),
                value_tensor=torch.tensor([1], dtype=torch.int32),
            ),
            param(
                test_name="1d_indices_multiple",
                source_tensor=torch.zeros([5], dtype=torch.int32),
                indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
                value_tensor=torch.tensor([1, 3], dtype=torch.int32),
            ),
            param(
                test_name="2d_indices_single",
                source_tensor=torch.zeros([5, 5], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([2], dtype=torch.int32),
                    torch.tensor([0], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([3], dtype=torch.int32),
            ),
            param(
                test_name="2d_indices_multiple",
                source_tensor=torch.zeros([5, 5], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([0, 2, 2], dtype=torch.int32),
                    torch.tensor([2, 0, 2], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([1, 3, 4], dtype=torch.int32),
            ),
            param(
                test_name="3d_indices_single",
                source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([1], dtype=torch.int32),
                    torch.tensor([2], dtype=torch.int32),
                    torch.tensor([2], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([7], dtype=torch.int32),
            ),
            param(
                test_name="3d_indices_multiple",
                source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([0, 1, 1], dtype=torch.int32),
                    torch.tensor([1, 2, 1], dtype=torch.int32),
                    torch.tensor([2, 0, 2], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([5, 7, 2], dtype=torch.int32),
            ),
            param(
                test_name="4d_indices_single",
                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([1], dtype=torch.int32),
                    torch.tensor([1], dtype=torch.int32),
                    torch.tensor([0], dtype=torch.int32),
                    torch.tensor([1], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([5], dtype=torch.int32),
            ),
            param(
                test_name="4d_indices_multiple",
                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([0, 1], dtype=torch.int32),
                    torch.tensor([1, 1], dtype=torch.int32),
                    torch.tensor([1, 0], dtype=torch.int32),
                    torch.tensor([1, 0], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([5, 7], dtype=torch.int32),
            ),
            param(
                test_name="negative_indices",
                source_tensor=torch.zeros([5, 5], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([-1, -2], dtype=torch.int32),
                    torch.tensor([2, 0], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([1, 3], dtype=torch.int32),
            ),
            param(
                test_name="mixed_indices",
                source_tensor=torch.zeros([4, 4], dtype=torch.int32),
                indices_tensor=(
                    torch.tensor([0, 1, -1, -2], dtype=torch.int32),
                    torch.tensor([0, -1, 2, 1], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([2, 4, 6, 8], dtype=torch.int32),
            ),
            param(
                test_name="1d_indices_float",
                source_tensor=torch.zeros([5], dtype=torch.float32),
                indices_tensor=(torch.tensor([0, 3], dtype=torch.int32),),
                value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
            ),
            param(
                test_name="2d_indices_float",
                source_tensor=torch.zeros([5, 5], dtype=torch.float32),
                indices_tensor=(
                    torch.tensor([0, 2], dtype=torch.int32),
                    torch.tensor([2, 0], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([1.5, 3.5], dtype=torch.float32),
            ),
            param(
                test_name="3d_indices_float",
                source_tensor=torch.zeros([3, 3, 3], dtype=torch.float32),
                indices_tensor=(
                    torch.tensor([0, 1], dtype=torch.int32),
                    torch.tensor([1, 2], dtype=torch.int32),
                    torch.tensor([2, 0], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
            ),
            param(
                test_name="4d_indices_float",
                source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.float32),
                indices_tensor=(
                    torch.tensor([0, 1], dtype=torch.int32),
                    torch.tensor([1, 0], dtype=torch.int32),
                    torch.tensor([0, 1], dtype=torch.int32),
                    torch.tensor([1, 0], dtype=torch.int32),
                ),
                value_tensor=torch.tensor([5.5, 7.5], dtype=torch.float32),
            ),
            # The cases below are disabled: index broadcasting and accumulate=True
            # are not yet supported by this converter (see the review discussion).
            # param(
            #     test_name="3d_indices_float_broadcase_index",
            #     source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
            #     indices_tensor=(
            #         torch.tensor([0, 1], dtype=torch.int32),
            #         torch.tensor([0, 1], dtype=torch.int32),
            #     ),
            #     value_tensor=torch.tensor([10], dtype=torch.int32),
            # ),
            # param(
            #     test_name="2d_indices_accumulate_True",
            #     source_tensor=torch.zeros([5, 5], dtype=torch.int32),
            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
            #     accumulate=True,
            # ),
            # param(
            #     test_name="3d_indices_accumulate_True",
            #     source_tensor=torch.zeros([3, 3, 3], dtype=torch.int32),
            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([2, 2], dtype=torch.int32)),
            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
            #     accumulate=True,
            # ),
            # param(
            #     test_name="4d_indices_accumulate_True",
            #     source_tensor=torch.zeros([2, 2, 2, 2], dtype=torch.int32),
            #     indices_tensor=(torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32), torch.tensor([0, 0], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)),
            #     value_tensor=torch.tensor([1, 2], dtype=torch.int32),
            #     accumulate=True,
            # ),
        ]
    )
    def test_index_put(
        self, test_name, source_tensor, indices_tensor, value_tensor, accumulate=False
    ):
        # Mark the tuple of index tensors as a constant for the dynamo tracer.
        @torch._dynamo.assume_constant_result
        def get_indices_tensor():
            return indices_tensor

        class TestIndexPut(torch.nn.Module):
            def forward(self, source_tensor, value_tensor):
                indices_tensor_const = get_indices_tensor()
                return torch.ops.aten.index_put.default(
                    source_tensor, indices_tensor_const, value_tensor, accumulate
                )

        self.run_test(
            TestIndexPut(),
            inputs=[source_tensor, value_tensor],
            enable_passes=True,
            use_dynamo_tracer=True,
        )


if __name__ == "__main__":
    run_tests()
Conversations
Since indices can be an ITensor per the schema, you may not be able to iterate over it. In the test case, you can try changing line 173 to inputs=[source_tensor, indices_tensor, value_tensor]. It's kind of similar to the offsets in the annoying embedding_bag. You can think about how to use native TRT layers to do this, like ILoop.
Besides, what blocks you when accumulate=True?
Thank you very much for your review. When indices is a torch.Tensor, an error occurs in PyTorch, as shown in the example below. This situation is somewhat different from embedding_bag: it is the case where the input is a tuple of tensors, which we discussed earlier. As the example shows, index_put_ throws an error when indices is a plain torch.Tensor and only works correctly when indices is a tuple or list. Therefore, indices can be iterated over with a for loop, and I did not use a for loop over each_input since it is an ITensor. If I am mistaken, your comments would be very helpful. One more question about the type annotation of indices when it is a tuple of tensors: is it correct to define indices as Union[TRTTensor, Tuple[TRTTensor, ...]]?
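A minimal repro of the behavior described here (my own sketch, standing in for the example originally attached to this comment) could look like:

import torch

src = torch.zeros(5, dtype=torch.int32)
values = torch.tensor([1, 3], dtype=torch.int32)
idx = torch.tensor([0, 3])  # int64 index tensor

# Works: `indices` is passed as a tuple (or list) of index tensors.
src.index_put_((idx,), values)

# Rejected: a bare Tensor is not accepted for `indices`, since the schema
# expects a sequence of (optional) index tensors.
try:
    src.index_put_(idx, values)
except (TypeError, RuntimeError) as err:
    print("index_put_ rejected a bare Tensor:", err)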
When accumulate=True, if there are duplicate index pairs in indices, the corresponding values should be summed and the duplicates removed. Therefore, I aimed to obtain indices without duplicated pairs, together with the correspondingly modified values, and then feed these into the scatter layer. However, I encountered difficulties implementing the loop that checks for duplicate index pairs in indices.
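For reference, the preprocessing described above can be sketched in eager PyTorch roughly as follows (my own illustration with a hypothetical helper name; the hard part is expressing the same logic with TensorRT layers inside the converter):

import torch

def dedup_and_sum(indices, values):
    # Collapse duplicate index pairs and sum their values, emulating accumulate=True.
    stacked = torch.stack(indices, dim=-1)  # shape (N, rank): one row per index pair
    unique_idx, inverse = torch.unique(stacked, dim=0, return_inverse=True)
    summed = torch.zeros(unique_idx.shape[0], dtype=values.dtype)
    summed.index_add_(0, inverse, values)   # accumulate values per unique index pair
    return tuple(unique_idx.unbind(dim=-1)), summed

indices = (torch.tensor([0, 0]), torch.tensor([1, 1]))
values = torch.tensor([1, 2], dtype=torch.int32)
new_indices, new_values = dedup_and_sum(indices, values)
# new_indices == (tensor([0]), tensor([1])), new_values == tensor([3])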
Thanks for the detailed explanations! Yes, you are right: the indices should be a list or tuple, and thus it can be iterated over, so your current implementation LGTM. I think the annotation could be Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]], since a single TRTTensor cannot be iterated over and that matches the schema, right?
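Concretely, that suggestion would put the annotation on the converter helper along these lines (the signature below is hypothetical and only illustrates where the annotation goes; the TRTTensor import path is my assumption):

from typing import Sequence, Union

import numpy as np
import torch
from torch_tensorrt.fx.types import TRTTensor  # assumed location of the TRTTensor alias

def index_put_converter(  # hypothetical signature, for illustration only
    input_tensor: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
    values: TRTTensor,
    accumulate: bool = False,
) -> TRTTensor:
    ...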
When you say accumulate=True is causing issues, I believe it is the duplicate indices that cause them. I faced the same problem in scatter_reduce, and I believe advanced indexing would be the way to deal with it (lengthy code, I believe :( ). Do you have any other ideas?
I have written a validator to handle the accumulate=True case, and I have created a separate issue for implementing the converter for accumulate=True. It would be great to share ideas and work together on this.
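For context, such a validator amounts to something like the sketch below, registered through Torch-TensorRT's dynamo_tensorrt_converter decorator and its capability_validator hook (the import path and the converter stub are my assumptions, not the exact code in this PR):

import torch
from torch.fx.node import Node

# Assumed import path for the converter-registry decorator.
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter


def index_put_validator(node: Node) -> bool:
    # aten.index_put.default(self, indices, values, accumulate=False):
    # report the op as unsupported when accumulate=True so it falls back to PyTorch.
    accumulate = node.args[3] if len(node.args) > 3 else node.kwargs.get("accumulate", False)
    return accumulate is False


@dynamo_tensorrt_converter(
    torch.ops.aten.index_put.default, capability_validator=index_put_validator
)
def aten_ops_index_put(ctx, target, args, kwargs, name):
    ...  # converter body omitted in this sketch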