Commit 1717018

rebase

1 parent 732de3b

4 files changed (+25 lines, −90 lines)


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 10 additions & 48 deletions
@@ -138,12 +138,12 @@ def aten_ops_sigmoid(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.index.Tensor)  # type: ignore[misc]
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
     }
-)
+)  # type: ignore[misc]
 def aten_ops_index(
     ctx: ConversionContext,
     target: Target,
@@ -723,27 +723,8 @@ def one_user_validator(node: Node) -> bool:


 @dynamo_tensorrt_converter(torch.ops.aten.max.default)  # type: ignore[misc]
-def aten_ops_max(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.reduce.max(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        dim=None,
-        keepdim=False,
-        return_indices=False,
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.max.dim, capability_validator=one_user_validator)  # type: ignore[misc]
-def aten_ops_max_dim(
+def aten_ops_max(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -756,34 +737,15 @@ def aten_ops_max_dim(
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
-        args_bounds_check(args, 2, replacement=False),
-        return_indices=True,
+        dim=args_bounds_check(args, 1, replacement=None),
+        keepdim=args_bounds_check(args, 2, replacement=False),
+        return_indices=(target == torch.ops.aten.max.dim),
     )


 @dynamo_tensorrt_converter(torch.ops.aten.min.default)  # type: ignore[misc]
-def aten_ops_min(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.reduce.min(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        dim=None,
-        keepdim=False,
-        return_indices=False,
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.min.dim, capability_validator=one_user_validator)  # type: ignore[misc]
-def aten_ops_min_dim(
+def aten_ops_min(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
@@ -796,9 +758,9 @@ def aten_ops_min_dim(
         SourceIR.ATEN,
         name,
         args[0],
-        args[1],
-        args_bounds_check(args, 2, replacement=False),
-        return_indices=True,
+        dim=args_bounds_check(args, 1, replacement=None),
+        keepdim=args_bounds_check(args, 2, replacement=False),
+        return_indices=(target == torch.ops.aten.min.dim),
     )

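The collapsed converters above register one function for both the `.default` and `.dim` overloads: `args_bounds_check` fills in the optional positional arguments, and `return_indices` is keyed off which overload was matched. Below is a minimal standalone sketch of that dispatch pattern in plain PyTorch; `reduce_max` and the inlined `args_bounds_check` are illustrative stand-ins, not the TensorRT conversion path.

from typing import Any, Sequence

import torch


def args_bounds_check(args: Sequence[Any], i: int, replacement: Any = None) -> Any:
    # Safe positional lookup: return args[i] if it was supplied, else the default.
    return args[i] if len(args) > i else replacement


def reduce_max(target: Any, args: Sequence[Any]) -> Any:
    # One handler for both overloads, mirroring the merged aten_ops_max.
    dim = args_bounds_check(args, 1, replacement=None)
    keepdim = args_bounds_check(args, 2, replacement=False)
    if target == torch.ops.aten.max.dim:
        return torch.max(args[0], dim=dim, keepdim=keepdim)  # (values, indices)
    return torch.max(args[0])  # full reduction to a single 0-dim tensor


x = torch.randn(2, 3)
print(reduce_max(torch.ops.aten.max.default, (x,)))      # 0-dim tensor
print(reduce_max(torch.ops.aten.max.dim, (x, 1, True)))  # values and indices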

tests/py/dynamo/conversion/harness.py

Lines changed: 12 additions & 29 deletions
@@ -197,35 +197,18 @@ def generate_graph(
         use_dynamo_tracer: bool,
         enable_passes: bool,
     ):
-        # Torchdynamo+aot proxytensor tracer
-        # Below are common passes
-        passes_list = [
-            compose_bmm,
-            compose_chunk,
-            compose_getitem_slice,
-            replace_aten_reshape_alias_with_replace,
-            replace_aten_op_with_indices,
-            replace_transpose_mm_op_with_linear,  # after compose_bmm
-            replace_native_layernorm_with_layernorm,
-            remove_ops,
-            replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
-        ]
-        # Combine with customized passes specific to any model
-        if customized_passes:
-            passes_list.extend(customized_passes)
-
-        if disable_passes:
-            passes_list = []
-
-        fx_module, _ = aten_tracer.trace(mod, original_inputs)
-        for passes in passes_list:
-            pr: PassResult = passes(fx_module)
-            fx_module = pr.graph_module
-        fx_module(*original_inputs)
-
-        fx_module = run_const_fold(fx_module)
-        fx_module.graph.eliminate_dead_code()
-
+        if use_dynamo_tracer:
+            fx_module = torch._dynamo.export(
+                mod,
+                *original_inputs,
+                aten_graph=True,
+                assume_static_by_default=True,
+                tracing_mode="real",
+            ).graph_module
+        else:
+            fx_module = torch.fx.symbolic_trace(mod)
+        if enable_passes:
+            fx_module = apply_lowering_passes(fx_module, original_inputs)
         _LOGGER.info(f"FX graph= {fx_module.graph}")
         return fx_module

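The rewritten harness replaces the legacy FX pass pipeline with a branch on `use_dynamo_tracer`: dynamo export produces an aten-level graph, while `torch.fx.symbolic_trace` keeps torch-level calls. A standalone sketch of the new tracing call follows, assuming a PyTorch 2.x build where `torch._dynamo.export` accepts the flags exactly as the diff uses them; `TinyModel` is a hypothetical stand-in for the module under test.

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.max(x, dim=1, keepdim=True)[0]


original_inputs = [torch.randn(2, 3)]
fx_module = torch._dynamo.export(
    TinyModel(),
    *original_inputs,
    aten_graph=True,                # decompose to torch.ops.aten.* targets
    assume_static_by_default=True,  # treat unannotated dims as static
    tracing_mode="real",            # trace with real tensors, no fake mode
).graph_module
fx_module.graph.print_tabular()     # inspect the aten-level nodes
print(fx_module(*original_inputs))  # the traced module stays callable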

tests/py/dynamo/conversion/test_max_aten.py

Lines changed: 1 addition & 6 deletions
@@ -18,13 +18,12 @@ class TestMaxConverter(DispatchTestCase):
     def test_max_dim_int_default(self, input_shape):
         class Max(nn.Module):
             def forward(self, x):
-                return torch.max(x)
+                return torch.ops.aten.max.default(x)

         inputs = [torch.randn(*input_shape)]
         self.run_test(
             Max(),
             inputs,
-            expected_ops={torch.ops.aten.max.default},
         )

     @parameterized.expand(
@@ -45,8 +44,6 @@ def forward(self, x):
         self.run_test(
             Max(),
             inputs,
-            expected_ops={torch.ops.aten.max.dim},
-            disable_passes=True,
         )

     @parameterized.expand(
@@ -65,9 +62,7 @@ def forward(self, x):
         self.run_test(
             Max(),
             inputs,
-            expected_ops={torch.ops.aten.max.dim},
             check_dtype=False,
-            disable_passes=True,
         )

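The tests now call the aten overloads directly instead of `torch.max`, so no lowering pass decides which overload appears in the graph and the `expected_ops`/`disable_passes` arguments can go. A quick standalone check of the two signatures being exercised:

import torch

x = torch.randn(3, 4)

# max.default: full reduction, single 0-dim tensor.
overall = torch.ops.aten.max.default(x)

# max.dim: positional (self, dim, keepdim), returns (values, indices).
values, indices = torch.ops.aten.max.dim(x, 1, True)

print(overall.shape, values.shape, indices.shape)
# torch.Size([]) torch.Size([3, 1]) torch.Size([3, 1])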

tests/py/dynamo/conversion/test_min_aten.py

Lines changed: 2 additions & 7 deletions
@@ -24,7 +24,6 @@ def forward(self, x):
         self.run_test(
             Min(),
             inputs,
-            expected_ops={torch.ops.aten.min.default},
         )

     @parameterized.expand(
@@ -39,14 +38,12 @@ def forward(self, x):
     def test_min_dim_int(self, input_shape, dim, keep_dims):
         class Min(nn.Module):
             def forward(self, x):
-                return torch.min(x, dim=dim, keepdim=keep_dims)[0]
+                return torch.ops.aten.min.dim(x, dim, keep_dims)[0]

         inputs = [torch.randn(*input_shape)]
         self.run_test(
             Min(),
             inputs,
-            expected_ops={torch.ops.aten.min.dim},
-            disable_passes=True,
         )

     @parameterized.expand(
@@ -59,15 +56,13 @@ def forward(self, x):
     def test_min_dim_int_int(self, input_shape, dim, keep_dims, dtype, low, high):
         class Min(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.min.dim(x, dim=dim, keepdim=keep_dims)[0]
+                return torch.ops.aten.min.dim(x, dim, keep_dims)[0]

         inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
         self.run_test(
             Min(),
             inputs,
-            expected_ops={torch.ops.aten.min.dim},
             check_dtype=False,
-            disable_passes=True,
         )

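The min tests mirror the max changes, including the switch from keyword to positional arguments on `min.dim`. A small sanity check of what the updated `forward` computes; the comparison against `torch.min` is illustrative, not part of the test suite:

import torch

x = torch.randint(-5, 5, (3, 4), dtype=torch.int32)

# Same call shape as the updated test: positional dim and keepdim,
# then index [0] to keep only the values tensor.
values = torch.ops.aten.min.dim(x, 1, False)[0]

assert torch.equal(values, torch.min(x, dim=1).values)
print(values)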
