
Commit 387c2f9

Reorg for converters in hardtanh (FX Converter Refactor [5/N]) <Target: converter_reorg_proto> (#1901)

1 parent 6d28bba commit 387c2f9

File tree

5 files changed (+117 / -12 lines)

py/torch_tensorrt/fx/converters/acc_ops_converters.py (4 additions, 12 deletions)

@@ -3585,23 +3585,15 @@ def acc_ops_hardtanh(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
 
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"hardtanh received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    return activation.convert_activation(
+    return activation.hardtanh(
         network,
         target,
         SourceIR.ACC,
         name,
-        trt.ActivationType.CLIP,
-        input_val,
-        alpha=kwargs["min_val"],
-        beta=kwargs["max_val"],
+        kwargs["input"],
+        kwargs["min_val"],
+        kwargs["max_val"],
     )
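Note that the TRTTensor guard deleted above does not simply disappear: the shared path in impl/activation.py is assumed to perform an equivalent check. A minimal sketch of that guard, assuming TRTTensor is importable from torch_tensorrt.fx.types; _require_trt_tensor is a hypothetical name, not the library's API:

    from torch_tensorrt.fx.types import TRTTensor

    # Hypothetical helper mirroring the guard removed above; the shared
    # convert_activation path is assumed to raise an equivalent error.
    def _require_trt_tensor(input_val, op_name: str) -> TRTTensor:
        if not isinstance(input_val, TRTTensor):
            raise RuntimeError(
                f"{op_name} received input {input_val} that is not part "
                "of the TensorRT region!"
            )
        return input_val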
py/torch_tensorrt/fx/converters/aten_ops_converters.py (14 additions, 0 deletions)

@@ -201,6 +201,20 @@ def aten_ops_fmod(
     return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)
 
 
+@tensorrt_converter(torch.ops.aten.hardtanh.default)
+def aten_ops_hardtanh(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return activation.hardtanh(
+        network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+    )
+
+
 @tensorrt_converter(torch.ops.aten.linear)
 def aten_ops_linear(
     network: TRTNetwork,
py/torch_tensorrt/fx/converters/impl/activation.py (31 additions, 0 deletions)

@@ -69,6 +69,37 @@ def convert_activation(
     return layer.get_output(0)
 
 
+def hardtanh(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any] = None,
+    beta: Optional[Any] = None,
+):
+    operation_type = trt.ActivationType.CLIP
+
+    def hardtanh_dyn_range_fn(dyn_range):
+        def hardtanh_fn(x):
+            # TODO: Called torch.nn.functional.hardtanh
+            return torch.nn.functional.hardtanh(x)
+
+        # Apply hardtanh to both ends of the incoming dynamic range.
+        return hardtanh_fn(dyn_range[0]), hardtanh_fn(dyn_range[1])
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        beta,
+        dyn_range_fn=hardtanh_dyn_range_fn,
+    )
+
+
 def relu(
     network: TRTNetwork,
     target: Target,
py/torch_tensorrt/fx/converters/nn_ops_converters.py (15 additions, 0 deletions)

@@ -36,3 +36,18 @@ def sigmoid(network, submod, args, kwargs, layer_name):
         name=layer_name,
         input_val=kwargs["input"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.hardtanh)
+@tensorrt_converter(torch.nn.modules.activation.Hardtanh)
+def hardtanh(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.hardtanh(
+        network=network,
+        target="torch.nn.modules.activation.Hardtanh",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
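This converter relies on the frontend having normalized call args into kwargs, and it forwards only kwargs["input"], leaving alpha and beta at their defaults. In eager terms, the module it covers behaves as follows:

    import torch
    import torch.nn as nn

    # Eager-mode reference for the module the NN-level converter covers.
    m = nn.Hardtanh(min_val=-2.0, max_val=2.0)
    y = m(torch.randn(2, 3))
    assert y.min().item() >= -2.0 and y.max().item() <= 2.0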
New test file (53 additions, 0 deletions)

@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestHardTanHConverter(DispatchTestCase):
+    def test_hardtanh(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.hardtanh(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.hardtanh.default}
+        )
+
+    def test_hardtanh_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.hardtanh(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default}
+        )
+
+    def test_hardtanh_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.hardtanh(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
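The dynamic-shape tests drive the converter through InputTensorSpec; each shape_ranges entry is assumed to be a (min_shape, opt_shape, max_shape) triple, matching TensorRT optimization profiles, and -1 in shape marks the dimensions those ranges make dynamic. For example, the first spec above accepts any input between (1, 1, 1) and (3, 3, 3):

    import torch
    from torch_tensorrt.fx.tools.common_fx2trt import InputTensorSpec

    # Assumed semantics, consistent with TensorRT optimization profiles:
    # shape_ranges holds (min_shape, opt_shape, max_shape) triples.
    spec = InputTensorSpec(
        shape=(-1, -1, -1),
        dtype=torch.float32,
        shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
    )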
