Skip to content

Commit 9ecae98

Browse files
committed
[dynamo/converter] support neg converter
1 parent d630a1e commit 9ecae98

File tree

3 files changed

+85
-0
lines changed

3 files changed

+85
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,27 @@ def aten_ops_rsqrt(
353353
args[0],
354354
)
355355

356+
@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
357+
def aten_ops_neg(
358+
network: TRTNetwork,
359+
target: Target,
360+
args: Tuple[Argument, ...],
361+
kwargs: Dict[str, Argument],
362+
name: str,
363+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
364+
input_val = args[0]
365+
if (isinstance(input_val, TRTTensor)) and (
366+
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
367+
):
368+
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
369+
370+
return impl.unary.neg(
371+
network,
372+
target,
373+
SourceIR.ATEN,
374+
name,
375+
input_val,
376+
)
356377

357378
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim) # type: ignore[misc]
358379
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims) # type: ignore[misc]

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,15 @@ def isinf(
384384
return convert_unary(
385385
network, target, source_ir, name, trt.UnaryOperation.ISINF, input_val
386386
)
387+
388+
389+
def neg(
390+
network: TRTNetwork,
391+
target: Target,
392+
source_ir: Optional[SourceIR],
393+
name: str,
394+
input_val: TRTTensor,
395+
) -> TRTTensor:
396+
return convert_unary(
397+
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
398+
)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
6+
from torch_tensorrt import Input
7+
8+
9+
class TestNegConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("2d_dim_dtype_float", (2, 2), torch.float),
13+
("3d_dim_dtype_float", (2, 2, 2), torch.float),
14+
("2d_dim_dtype_half", (2, 2), torch.half),
15+
("3d_dim_dtype_half", (2, 2, 2), torch.half),
16+
]
17+
)
18+
def test_neg_float(self, _, x, type):
19+
class neg(nn.Module):
20+
def forward(self, input):
21+
return torch.neg(input)
22+
23+
inputs = [torch.randn(x, dtype=type)]
24+
self.run_test(
25+
neg(),
26+
inputs,
27+
precision=type,
28+
expected_ops={torch.ops.aten.neg.default},
29+
)
30+
31+
@parameterized.expand(
32+
[
33+
("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
34+
("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
35+
]
36+
)
37+
def test_neg_int(self, _, x, type, min, max):
38+
class neg(nn.Module):
39+
def forward(self, input):
40+
return torch.neg(input)
41+
42+
inputs = [torch.randint(min, max, x, dtype=type)]
43+
self.run_test(
44+
neg(),
45+
inputs,
46+
output_dtypes=[torch.int32],
47+
expected_ops={torch.ops.aten.neg.default},
48+
)
49+
50+
51+
if __name__ == "__main__":
52+
run_tests()

0 commit comments

Comments
 (0)