Skip to content

Commit 1fec519

Browse files
authored
add initial support for torch.ops.aten.neg.default converter (#2147)
1 parent bb5bf00 commit 1fec519

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+23
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,29 @@ def aten_ops_rsqrt(
354354
)
355355

356356

357+
@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
358+
def aten_ops_neg(
359+
network: TRTNetwork,
360+
target: Target,
361+
args: Tuple[Argument, ...],
362+
kwargs: Dict[str, Argument],
363+
name: str,
364+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
365+
input_val = args[0]
366+
if (isinstance(input_val, TRTTensor)) and (
367+
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
368+
):
369+
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
370+
371+
return impl.unary.neg(
372+
network,
373+
target,
374+
SourceIR.ATEN,
375+
name,
376+
input_val,
377+
)
378+
379+
357380
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim) # type: ignore[misc]
358381
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims) # type: ignore[misc]
359382
def aten_ops_squeeze(

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

+12
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,15 @@ def isinf(
384384
return convert_unary(
385385
network, target, source_ir, name, trt.UnaryOperation.ISINF, input_val
386386
)
387+
388+
389+
def neg(
390+
network: TRTNetwork,
391+
target: Target,
392+
source_ir: Optional[SourceIR],
393+
name: str,
394+
input_val: TRTTensor,
395+
) -> TRTTensor:
396+
return convert_unary(
397+
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
398+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
from torch_tensorrt.dynamo.test_utils import DispatchTestCase
7+
8+
9+
class TestNegConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
("2d_dim_dtype_float", (2, 2), torch.float),
13+
("3d_dim_dtype_float", (2, 2, 2), torch.float),
14+
("2d_dim_dtype_half", (2, 2), torch.half),
15+
("3d_dim_dtype_half", (2, 2, 2), torch.half),
16+
]
17+
)
18+
def test_neg_float(self, _, x, type):
19+
class neg(nn.Module):
20+
def forward(self, input):
21+
return torch.neg(input)
22+
23+
inputs = [torch.randn(x, dtype=type)]
24+
self.run_test(
25+
neg(),
26+
inputs,
27+
precision=type,
28+
expected_ops={torch.ops.aten.neg.default},
29+
)
30+
31+
@parameterized.expand(
32+
[
33+
("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
34+
("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
35+
]
36+
)
37+
def test_neg_int(self, _, x, type, min, max):
38+
class neg(nn.Module):
39+
def forward(self, input):
40+
return torch.neg(input)
41+
42+
inputs = [torch.randint(min, max, x, dtype=type)]
43+
self.run_test(
44+
neg(),
45+
inputs,
46+
output_dtypes=[torch.int32],
47+
expected_ops={torch.ops.aten.neg.default},
48+
)
49+
50+
51+
if __name__ == "__main__":
52+
run_tests()

0 commit comments

Comments
 (0)