Commit 892f7d0

fix: Add special cases where input of graph is output
- TRT does not allow inputs of graphs to also be outputs; however, many scenarios encountered in real models produce this situation, especially when an input is cloned or copied and then returned
- The current converters register these operators as no-ops, causing TRT engine building to fail on such inputs
- Instead of requiring an identity layer for every clone or copy node, we check whether that node is the only operator on a placeholder (input) and insert the identity layer only when needed
- Coalesce the implementations of clone and to_copy, which are effectively the same operator
- Add test cases to validate the new behavior
- Add a new boilerplate converter validator utility to support this case
1 parent 8ebb599 commit 892f7d0
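
For context, a minimal sketch of the scenario the commit message describes (illustrative only, not part of the commit): a module whose output is simply a copy of its input, which lowers to a graph in which a placeholder feeds a clone/copy node that is returned directly.

import torch
import torch.nn as nn


class CloneOnly(nn.Module):
    # After tracing, the clone is the only operator between the graph input
    # (placeholder) and the graph output; previously its converter was a
    # no-op, leaving the engine input to double as the engine output.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone()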

File tree

5 files changed (+153, -67 lines)


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 65 additions & 19 deletions
@@ -1,10 +1,13 @@
 import logging
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    is_only_operator_on_placeholder,
+)
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
 from .converter_registry import dynamo_tensorrt_converter
@@ -441,29 +444,59 @@ def aten_ops_permute(
     )
 
 
-def to_copy_dtype_validator(to_copy_node: Node) -> bool:
-    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
-
-    # Validate input node has convertible kwargs
-    if "dtype" in to_copy_node.kwargs:
-        if to_copy_node.kwargs["dtype"] in allowed_casts:
-            return True
+def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+    """Return validator for to_copy node with placeholder restrictions"""
+
+    def validate_dtype(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+
+        Based on data type being casted to
+        """
+        allowed_casts = {
+            torch.float,
+            torch.int32,
+            torch.bool,
+            torch.int8,
+            torch.float16,
+        }
+
+        # Validate input node has convertible kwargs
+        if "dtype" in to_copy_node.kwargs:
+            if to_copy_node.kwargs["dtype"] in allowed_casts:
+                return True
+            else:
+                _LOGGER.debug(
+                    f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                )
+                return False
         else:
             _LOGGER.debug(
-                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+                f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
             )
             return False
-    else:
-        _LOGGER.debug(
-            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+
+    def validator(to_copy_node: Node) -> bool:
+        """Returns true if the to_copy node can be converted to TRT
+        and the placeholder restriction is satisfied
+        """
+        # The placeholder restriction is satsfied if placeholder_only is the same
+        # truth value as is_only_operator_on_placeholder(to_copy_node)
+        return validate_dtype(to_copy_node) and (
+            (not placeholder_only) ^ is_only_operator_on_placeholder(to_copy_node)
        )
-        return False
+
+    return validator
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+    torch.ops.aten.clone.default,
+    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
 )  # type: ignore[misc]
-def aten_ops_to_copy_dtype(
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=False),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_dtype(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
@@ -476,24 +509,37 @@ def aten_ops_to_copy_dtype(
         SourceIR.ATEN,
         name,
         args[0],
-        kwargs["dtype"],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=False,
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.clone.default)  # type: ignore[misc]
-def aten_ops_clone(
+@dynamo_tensorrt_converter(
+    torch.ops.aten.clone.default,
+    capability_validator=is_only_operator_on_placeholder,
+)  # type: ignore[misc]
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default,
+    capability_validator=to_copy_dtype_validator(placeholder_only=True),
+)  # type: ignore[misc]
+def aten_ops_clone_copy_placeholder(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.cast.clone(
+    # For clone or copy nodes where the input is also the output,
+    # we need to force cast to ensure a layer is added to the TRT engine
+    # since TRT engine inputs cannot also be TRT engine outputs
+    return impl.cast.to_copy(
         network,
         target,
         SourceIR.ATEN,
         name,
         args[0],
+        kwargs.get("dtype", args[0].dtype),
+        force_layer=True,
     )
 
 
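The two stacked registrations above split clone/_to_copy handling so that exactly one converter can claim a given node: aten_ops_clone_copy_dtype takes nodes that are not the sole operator on a placeholder, while aten_ops_clone_copy_placeholder takes the ones that are. A standalone sketch of that selection logic (hypothetical helper names, for illustration only; the actual choice is made by the converter registry through the capability validators):

def accepts_dtype_converter(dtype_ok: bool, only_op_on_placeholder: bool) -> bool:
    # Mirrors to_copy_dtype_validator(placeholder_only=False):
    # (not False) ^ only_op_on_placeholder == not only_op_on_placeholder
    return dtype_ok and not only_op_on_placeholder


def accepts_placeholder_converter(dtype_ok: bool, only_op_on_placeholder: bool) -> bool:
    # Mirrors to_copy_dtype_validator(placeholder_only=True):
    # (not True) ^ only_op_on_placeholder == only_op_on_placeholder
    return dtype_ok and only_op_on_placeholder


# For any node with a convertible dtype, exactly one of the two validators accepts it.
assert accepts_dtype_converter(True, False) and not accepts_placeholder_converter(True, False)
assert accepts_placeholder_converter(True, True) and not accepts_dtype_converter(True, True)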
py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 31 additions & 6 deletions
@@ -45,27 +45,49 @@ def get_node_name(node: torch.fx.Node) -> str:
     return node_name
 
 
+def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+    """Detects whether a call_function node is the only operator on a placeholder"""
+    # Returns true if the node operates on a placeholder and is a direct output
+    return (
+        node.op == "call_function"
+        and any(
+            arg.op == "placeholder"
+            for arg in node.args
+            if isinstance(arg, torch.fx.Node)
+        )
+        and any(user.op == "output" for user in list(node.users.keys()))
+    )
+
+
 def dynamic_unsupported(node: torch.fx.Node) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
         node, torch.fx.Node
     ), "Inputs to validator functions must be FX Nodes"
 
     # Check node value itself
-    if getattr(node.meta["val"], "_has_symbolic_sizes_strides", False):
+    if ("val" in node.meta) and getattr(
+        node.meta["val"], "_has_symbolic_sizes_strides", False
+    ):
         return False
 
     # Check node arguments individually
     if any(
-        getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in arg.meta)
+            and getattr(arg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for arg in node.args
        if isinstance(arg, torch.fx.Node)
    ):
         return False
 
     # Check node keyword arguments individually
     if any(
-        getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        (
+            ("val" in kwarg.meta)
+            and getattr(kwarg.meta["val"], "_has_symbolic_sizes_strides", False)
+        )
         for kwarg in node.kwargs.values()
         if isinstance(kwarg, torch.fx.Node)
     ):
@@ -82,9 +104,12 @@ def cast_trt_tensor(
     target: Target = "",
     source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
-    """
-    Given a TRT Tensor, convert that Tensor to the specified dtype
+    """Given a TRT Tensor, convert that Tensor to the specified dtype
+
     Adds an Identity layer to the network which performs the conversion
+    if the input's dtype is different from the cast type. Otherwise returns
+    input unchanged
+
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
@@ -191,7 +216,7 @@ def extend_attr_to_tuple(
     if isinstance(val, tuple):
         return val
     else:
-        raise AssertionError(f"Could not extend attribute {val}")
+        raise AssertionError(f"Object {val} could not be extended to tuple")
 
 
 def cast_int_or_float_to_bool(
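
As a quick illustration of the new is_only_operator_on_placeholder predicate (a sketch, not taken from the commit), a manually built FX graph in which a placeholder feeds a clone node that is returned directly satisfies both conditions checked above:

import torch
import torch.fx as fx

from torch_tensorrt.dynamo.conversion.converter_utils import (
    is_only_operator_on_placeholder,
)

graph = fx.Graph()
x = graph.placeholder("x")
cloned = graph.call_function(torch.ops.aten.clone.default, (x,))
graph.output(cloned)

# `cloned` is a call_function node, one of its args is a placeholder, and its
# only user is the output node, so the predicate should return True here.
print(is_only_operator_on_placeholder(cloned))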

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 21 additions & 19 deletions
@@ -3,7 +3,12 @@
 
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.converters.converter_utils import (
+    Frameworks,
+    unified_dtype_converter,
+)
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
 LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -16,28 +21,25 @@ def to_copy(
     name: str,
     input: TRTTensor,
     dtype: TRTDataType,
+    force_layer: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"to_copy received input {input} that is not a TensorRT ITensor"
         )
 
-    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
-    return casted_tensor
-
-
-def clone(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"clone received input {input} that is not a TensorRT ITensor"
-        )
-
-    LOGGER.debug(f"Evaluating clone on object with name: {name}")
-
-    return input
+    # If cast is forced, insert identity layer regardless of whether the dtype
+    # doesn't change
+    if force_layer:
+        trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
+        identity_layer = network.add_identity(input)
+        identity_layer.set_output_type(0, trt_dtype)
+        identity_layer.name = f"Forced Cast ITensor {input.name} from {input.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
+        return identity_layer.get_output(0)
+    else:
+        casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+        return casted_tensor
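
The force_layer branch above exists because TensorRT rejects a network whose marked output is the same ITensor as one of its inputs. A minimal sketch of that constraint using the tensorrt Python API directly (illustrative, independent of the converter code):

import tensorrt as trt

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 10))

# Marking `x` itself as the output is the situation this commit works around;
# routing it through an identity layer yields a distinct ITensor that can be
# marked as the output, which is what to_copy does when force_layer=True.
identity = network.add_identity(x)
identity.set_output_type(0, trt.float32)
network.mark_output(identity.get_output(0))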

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 9 additions & 23 deletions
@@ -11,11 +11,7 @@
     cast_trt_tensor,
     get_trt_tensor,
 )
-from torch_tensorrt.fx.converters.converter_utils import (
-    broadcast,
-    set_layer_name,
-    squeeze_left,
-)
+from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
@@ -96,10 +92,10 @@ def convert_binary_elementwise(
     is_rhs_trt_tensor = False
 
     if isinstance(lhs_val, TRTTensor):
-        lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
+        lhs_dtype = lhs_val.dtype
         is_lhs_trt_tensor = True
     if isinstance(rhs_val, TRTTensor):
-        rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
+        rhs_dtype = rhs_val.dtype
         is_rhs_trt_tensor = True
 
     if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -124,23 +120,13 @@ def convert_binary_elementwise(
     # dtype but we don't have a way to detect whether it makes sense for the
     # scalar to be float or half. Hence we go with the lhs dtype.
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = np.array([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = np.array([lhs_val], dtype=rhs_dtype)
-
-    # When lhs is scalar, and rhs has shape [1,], then currently the assert
-    # will fail because lhs shape has fewer dimensions than rhs shape. This
-    # happens when using implicit batch dimension, when we removed the 1st
-    # dimension from input tensor, causing it to have shape [] - a scalar. We
-    # fix it by reducing the rhs constant with a squeeze_left, so it becomes a
-    # scalar too. More generally, we squeeze_left on input if it's a constant
-    # tensor. This is safe because broadcast will pad dimensions on the left
-    # (prepend) to make lhs and rhs shape compatible.
-    if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
-            lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
-            rhs_val = squeeze_left(rhs_val)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_dtype, Frameworks.NUMPY)
+        )
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
     rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
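
With this change, lhs_dtype and rhs_dtype stay as TensorRT dtypes, and the conversion to a NumPy dtype happens only where a Python scalar is wrapped into an array. A small sketch of that conversion step (assuming the same torch_tensorrt.fx utilities the file imports):

import numpy as np
import tensorrt as trt
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

trt_dtype = trt.float16  # e.g. the dtype of the TRT tensor on the other operand
scalar = 2.0

# Convert the TRT dtype to its NumPy equivalent before building the constant,
# matching the new rhs_val / lhs_val handling in convert_binary_elementwise.
arr = np.array([scalar], dtype=unified_dtype_converter(trt_dtype, Frameworks.NUMPY))
print(arr.dtype)  # expected: float16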

tests/py/dynamo/conversion/test_casts.py

Lines changed: 27 additions & 0 deletions
@@ -35,6 +35,19 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_clone_direct(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                return x.clone()
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
 
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
@@ -83,6 +96,20 @@ def forward(self, x):
             disable_passes=True,
         )
 
+    def test_to_copy_direct(self):
+        class ToCopyFloat(nn.Module):
+            def forward(self, x):
+                return x.to(dtype=torch.float, copy=True)
+
+        inputs = [torch.rand((1, 3, 10)).float()]
+        self.run_test(
+            ToCopyFloat(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.float,
+            disable_passes=True,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
