Commit 67ade93

apbose authored and narendasan committed

Converter reorg elu

Adding selu converter; Python linting correction.
1 parent de8faa2 commit 67ade93

6 files changed, +219 -10 lines

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+11 -10

@@ -1040,11 +1040,14 @@ def acc_ops_elu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    alpha = kwargs["alpha"]
-    operation_type = trt.ActivationType.ELU
-    return activation.convert_activation(
-        network, target, SourceIR.ACC, name, operation_type, input_val, alpha=alpha
+
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        kwargs["input"],
+        kwargs["alpha"],
     )


@@ -1056,15 +1059,13 @@ def acc_ops_selu(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.ActivationType.SELU
-    return activation.convert_activation(
+
+    return activation.selu(
         network,
         target,
         SourceIR.ACC,
         name,
-        operation_type,
-        input_val,
+        kwargs["input"],
     )
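
Both acc converters now just forward their inputs to the shared helpers in impl/activation.py. For reference, the curve being converted is ELU(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise; the short PyTorch-only check below (not part of the commit) spells that out against torch.nn.functional.elu:

import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=7)
alpha = 1.0

# ELU: identity on the positive side, alpha * (exp(x) - 1) on the non-positive side.
manual = torch.where(x > 0, x, alpha * (torch.exp(x) - 1))
assert torch.allclose(manual, F.elu(x, alpha=alpha))
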
py/torch_tensorrt/fx/converters/aten_ops_converters.py

+27

@@ -170,6 +170,33 @@ def aten_ops_div(
     )


+@tensorrt_converter(torch.ops.aten.elu.default)
+def aten_ops_elu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    if len(args) > 2:
+        return activation.selu(
+            network,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+        )
+    return activation.elu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @tensorrt_converter(torch.ops.aten.floor_divide.default)
 def aten_ops_floor_div(
     network: TRTNetwork,
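
The len(args) > 2 branch leans on how SELU arrives at the aten level: torch.ops.aten.elu takes alpha, scale, and input_scale scalars after the input tensor, and torch.nn.functional.selu shows up as that same op with SELU's fixed constants filled in, which is also why the new SELU tests below expect torch.ops.aten.elu.default. A quick standalone check of that equivalence (illustration only, not converter code):

import torch

x = torch.randn(4, 8)

# SELU's fixed constants (Klambauer et al., 2017).
alpha, scale = 1.6732632423543772, 1.0507009873554805

# aten.elu(self, alpha, scale, input_scale) with these values reproduces F.selu.
assert torch.allclose(torch.ops.aten.elu(x, alpha, scale, 1.0), torch.nn.functional.selu(x))
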
py/torch_tensorrt/fx/converters/impl/activation.py

+48

@@ -202,3 +202,51 @@ def leaky_relu_dyn_range_fn(dyn_range):
         alpha,
         dyn_range_fn=leaky_relu_dyn_range_fn,
     )
+
+
+def elu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    alpha: Optional[Any],
+):
+    operation_type = trt.ActivationType.ELU
+
+    def elu_dyn_range_fn(dyn_range):
+        return (torch.nn.ELU(dyn_range[0]), torch.nn.ELU(dyn_range[1]))
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        alpha,
+        dyn_range_fn=elu_dyn_range_fn,
+    )
+
+
+def selu(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+):
+    operation_type = trt.ActivationType.SELU
+
+    def selu_dyn_range_fn(dyn_range):
+        return (torch.nn.SELU(dyn_range[0]), torch.nn.SELU(dyn_range[1]))
+
+    return convert_activation(
+        network,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+        dyn_range_fn=selu_dyn_range_fn,
+    )
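
Both new helpers route through the existing convert_activation utility defined earlier in this file (not shown in this diff), passing the TensorRT activation enum plus a dyn_range_fn callback, presumably used to recompute the tensor's dynamic range for quantized workflows. As a rough sketch of the kind of TensorRT layer such a utility builds, under the assumption that it adds a single IActivationLayer and forwards alpha when the op defines one (this is an illustration, not the file's actual implementation):

import tensorrt as trt

def add_activation_sketch(network, input_val, op_type, name, alpha=None):
    """Illustrative only: one TensorRT activation layer per converted op."""
    # op_type is a trt.ActivationType member, e.g. ELU or SELU.
    layer = network.add_activation(input_val, op_type)
    if alpha is not None:
        # ELU exposes a single tunable constant, stored on the layer.
        layer.alpha = alpha
    layer.name = name
    return layer.get_output(0)
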
py/torch_tensorrt/fx/converters/nn_ops_converters.py

+31

@@ -82,3 +82,34 @@ def leaky_relu(network, submod, args, kwargs, layer_name):
         input_val=kwargs["input"],
         alpha=kwargs["negative_slope"],
     )
+
+
+@tensorrt_converter(torch.nn.functional.elu)
+@tensorrt_converter(torch.nn.modules.activation.ELU)
+def elu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.elu(
+        network=network,
+        target="torch.nn.functional.elu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+        alpha=kwargs["alpha"],
+    )
+
+
+@tensorrt_converter(torch.nn.functional.selu)
+@tensorrt_converter(torch.nn.modules.activation.SELU)
+def selu(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    return activation.selu(
+        network=network,
+        target="torch.nn.functional.selu",
+        source_ir=SourceIR.NN,
+        name=layer_name,
+        input_val=kwargs["input"],
+    )
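
These module-level converters assume the fx path has already folded positional args into kwargs (hence the assert). The only op-specific state forwarded is ELU's alpha; SELU has no tunable parameters, so its converter only needs the input. A plain PyTorch illustration of that difference, independent of the converter machinery:

import torch
import torch.nn as nn

x = torch.randn(2, 3)

# nn.ELU carries an alpha attribute, which the ELU converter reads from kwargs["alpha"].
elu_mod = nn.ELU(alpha=0.5)
assert torch.allclose(elu_mod(x), nn.functional.elu(x, alpha=0.5))

# nn.SELU has fixed constants, so the SELU converter forwards only the input tensor.
selu_mod = nn.SELU()
assert torch.allclose(selu_mod(x), nn.functional.selu(x))
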
New file: ELU converter tests

+51

@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestELUConverter(DispatchTestCase):
+    def test_elu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_elu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_elu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.elu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

New file: SELU converter tests

+51

@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSeLUConverter(DispatchTestCase):
+    def test_selu(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default})
+
+    def test_selu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+    def test_selu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.selu(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

