Skip to content

Commit b50d12f

Browse files
authored
small fix: Index validator enable int64 (#2642) (#2643)
1 parent 5eb323f commit b50d12f

File tree

4 files changed

+9
-9
lines changed

4 files changed

+9
-9
lines changed

examples/dynamo/torch_compile_advanced_usage.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
4343
# For the default settings, we can simply call torch.compile
4444
# with the backend "torch_tensorrt", and run the model on an
4545
# input to cause compilation, as so:
46-
optimized_model = torch.compile(model, backend="torch_tensorrt")
46+
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False)
4747
optimized_model(*sample_inputs)
4848

4949
# %%
@@ -81,7 +81,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
8181

8282
# Run the model on an input to cause compilation, as so:
8383
optimized_model_custom = torch.compile(
84-
model_half, backend="torch_tensorrt", options=backend_kwargs
84+
model_half,
85+
backend="torch_tensorrt",
86+
options=backend_kwargs,
87+
dynamic=False,
8588
)
8689
optimized_model_custom(*sample_inputs_half)
8790

examples/dynamo/torch_compile_transformers_example.py

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
optimized_model = torch.compile(
6262
model,
6363
backend="torch_tensorrt",
64+
dynamic=False,
6465
options=compilation_kwargs,
6566
)
6667
optimized_model(*inputs)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def index_dtype_validator(node: Node) -> bool:
397397
for ind in index:
398398
if ind is not None:
399399
val = ind.meta.get("val")
400-
if val is not None and val.dtype != torch.int32:
400+
if val is not None and val.dtype not in (torch.int32, torch.int64):
401401
return False
402402
return True
403403

tests/py/dynamo/conversion/test_index_aten.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import operator
2-
31
import torch
42
import torch.nn as nn
53
from torch.testing._internal.common_utils import run_tests
6-
from torch_tensorrt import Input
4+
5+
from .harness import DispatchTestCase
76

87
from .harness import DispatchTestCase
98

@@ -16,7 +15,6 @@ def __init__(self):
1615
super().__init__()
1716

1817
def forward(self, x):
19-
index0 = torch.randint(0, 1, (1, 1))
2018
indices = [None, self.index0]
2119
out = torch.ops.aten.index.Tensor(x, indices)
2220
return out
@@ -159,8 +157,6 @@ def __init__(self):
159157
super().__init__()
160158

161159
def forward(self, x):
162-
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
163-
index1 = index0.unsqueeze(0).T.long()
164160
indices = [None, None, self.index0, self.index1]
165161
out = torch.ops.aten.index.Tensor(x, indices)
166162
return out

0 commit comments

Comments
 (0)