
Commit 2a26c98

fix: Add truncate_long_and_double to Dynamo
- Add default, setting, and function arguments for `truncate_long_and_double` in Dynamo
- Add utilities for repairing long/double inputs to TRT engines, including support for autocasting back to long/double after the computation completes
- Add multiple helper functions to enable easy testing and diagnosis of long/double IO to TRT engines
- Add necessary compiler code to enable usage of the `truncate_long_and_double` argument as a switch for the feature
- Add Dynamo compile support for the `truncate_long_and_double` compilation argument by intercepting long/double type inputs and casting them to their 32-bit counterparts prior to usage in TRT-accelerated subgraphs, then casting back if necessary
- Add robust logic to handle 64-bit inputs and outputs
- Add test cases for long and double scenarios
- Centralize the truncation utility for later use in the Dynamo export path
1 parent 5377e43 commit 2a26c98
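
Example usage of the new argument (a minimal sketch, not part of this commit: it assumes the Dynamo `compile` entrypoint in py/torch_tensorrt/dynamo/compile.py is exposed as `torch_tensorrt.dynamo.compile`, and the model and inputs below are illustrative only):

import torch
import torch_tensorrt

class AddLong(torch.nn.Module):
    def forward(self, x, offsets):
        # offsets arrives as torch.int64; TRT engines cannot consume 64-bit tensors,
        # so truncation casts it to int32 at the engine boundary and restores the
        # 64-bit type afterwards where Torch operations still require it
        return x + offsets.to(x.dtype)

model = AddLong().eval().cuda()
inputs = [
    torch.randn(4, 4).cuda(),             # float32 input, passed through unchanged
    torch.randint(0, 10, (4, 4)).cuda(),  # int64 (long) input, truncated to int32
]

# With the default (truncate_long_and_double=False), long/double inputs reach the
# TRT-accelerated subgraphs unmodified; enabling the flag inserts the casts
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=inputs,
    truncate_long_and_double=True,
)

print(trt_model(*inputs).shape)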

7 files changed: +341 -2 lines changed

py/torch_tensorrt/dynamo/_defaults.py (+1)

@@ -9,3 +9,4 @@
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
 USE_PYTHON_RUNTIME = None
+TRUNCATE_LONG_AND_DOUBLE = False

py/torch_tensorrt/dynamo/_settings.py (+2)

@@ -11,6 +11,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )
 
 
@@ -26,3 +27,4 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE

py/torch_tensorrt/dynamo/backend/backends.py (+14 -1)

@@ -16,7 +16,10 @@
     get_submod_inputs,
 )
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
-from torch_tensorrt.dynamo.conversion import convert_module
+from torch_tensorrt.dynamo.conversion import (
+    convert_module,
+    repair_long_or_double_inputs,
+)
 
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
@@ -135,6 +138,16 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )
 
+        # Ensure all submodule inputs do not require a gradient
+        for param in submodule_inputs:
+            param.requires_grad = False
+
+        # Handle long/double inputs if requested by the user
+        if settings.truncate_long_and_double:
+            submodule_inputs = repair_long_or_double_inputs(
+                partitioned_module, submodule, submodule_inputs, name
+            )
+
         # Create TRT Module from submodule
         trt_mod = convert_module(
             submodule,

py/torch_tensorrt/dynamo/compile.py (+3 -1)

@@ -30,6 +30,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )
 
 
@@ -53,7 +54,7 @@ def compile(
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
     calibrator=None,
-    truncate_long_and_double=False,
+    truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation=False,
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
@@ -109,6 +110,7 @@ def compile(
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
+        "truncate_long_and_double": truncate_long_and_double,
     }
 
     settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/conversion/__init__.py (+1)

@@ -1,2 +1,3 @@
 from .trt_interpreter import *
 from .conversion import *
+from .truncate_long_and_double import repair_long_or_double_inputs

py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py (+207)

@@ -0,0 +1,207 @@
+import torch
+from torch.fx.node import _get_qualified_name
+from typing import Optional, Sequence, Union
+
+
+def _extract_downstream_get_nodes(
+    module_node: torch.fx.Node, output_indices: Sequence[int]
+) -> Sequence[torch.fx.Node]:
+    """Extracts downstream users of a node which get the item at a particular index
+
+    Certain module-type nodes have multiple outputs (tuple of outputs). This function
+    returns downstream nodes which call the _operator.getitem function, which extracts
+    the element at a particular index in the tuple
+
+    Args:
+        module_node: FX module-type node to analyze
+        output_indices: Indices in the module node output to search for
+    Returns:
+        List of nodes which get the item at the specified index in the module node output
+    """
+    get_nodes = []
+
+    # Iterate over all downstream users of the node object
+    for user in module_node.users:
+        # If the user is a "get" node accessing the specified index, store it
+        if _get_qualified_name(user.target) == "_operator.getitem" and (
+            user.args[1] in output_indices
+        ):
+            get_nodes.append(user)
+
+    return get_nodes
+
+
+def _repair_64bit_input(
+    gm: torch.fx.GraphModule,
+    position: int,
+    submodule_name: str,
+    submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
+    dtype: torch.dtype,
+):
+    """Fixes a single Long/Double input to a TRT-accelerated subgraph
+
+    In-place modifies the provided graph
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        gm: FX GraphModule enclosing the TRT subgraph
+        position: Index in the submodule inputs at which the long or double input is found
+        submodule_name: Name of TRT-accelerated subgraph module in FX graph
+        submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        dtype: Data type of tensor at position in submodule (double/long)
+    """
+    assert dtype in (
+        torch.int64,
+        torch.float64,
+    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
+
+    # Determine target data type in 32 and 64 bit forms
+    dtype_64bit = dtype
+    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
+
+    # Find the node representing the submodule in the graph
+    module_node = None
+
+    # Iterate over all nodes in the graph, seeking target module name match
+    for n in gm.graph.nodes:
+        if n.op == "call_module" and str(n.target) == submodule_name:
+            module_node = n
+            break
+
+    if module_node is None:
+        raise AssertionError(
+            f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
+        )
+
+    # Extract the 64-bit node of the input
+    node_64bit = module_node.all_input_nodes[position]
+
+    # Prior to the module, insert a cast to the 32-bit equivalent node
+    with gm.graph.inserting_before(module_node):
+        node_32bit = gm.graph.call_function(
+            torch.ops.aten._to_copy.default,
+            args=(node_64bit,),
+            kwargs={"dtype": dtype_32bit},
+        )
+
+    # Replace 64-bit input to TRT module with new 32-bit cast node
+    module_node.replace_input_with(node_64bit, node_32bit)
+
+    output_positions_64bit = set()
+    outputs_list = (
+        [submodule_outputs]
+        if isinstance(submodule_outputs, torch.Tensor)
+        else submodule_outputs
+    )
+
+    # Determine if any outputs of the model are 64-bit type and store their indices
+    if submodule_outputs is not None:
+        for output_position, output in enumerate(outputs_list):
+            if output.dtype == dtype_64bit:
+                output_positions_64bit.add(output_position)
+
+    # Only enter this code block if there exists a 64-bit output
+    # This implies a cast is needed, since TRT cannot output 64-bit tensors
+    if output_positions_64bit:
+        # Determine whether the outputs of the module are tuple-type or not
+        is_collection_output = False
+        if isinstance(submodule_outputs, tuple):
+            is_collection_output = True
+
+        if not is_collection_output:
+            # If the output is a single tensor, insert a cast back to int64
+            with gm.graph.inserting_after(module_node):
+                cast_node_64bit = gm.graph.call_function(
+                    torch.ops.aten._to_copy.default,
+                    args=(module_node,),
+                    kwargs={"dtype": dtype_64bit},
+                )
+
+            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
+            module_node.replace_all_uses_with(
+                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
+            )
+
+        else:
+            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
+            get_nodes = _extract_downstream_get_nodes(
+                module_node, output_positions_64bit
+            )
+
+            # For each downstream user, append a cast node back to the 64-bit precision
+            for get_node in get_nodes:
+                with gm.graph.inserting_after(get_node):
+                    cast_node_64bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(get_node,),
+                        kwargs={"dtype": dtype_64bit},
+                    )
+
+                get_node.replace_all_uses_with(
+                    cast_node_64bit,
+                    delete_user_cb=lambda user: (user != cast_node_64bit),
+                )
+
+    # Clean up graph and ensure invariants are preserved
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+
+def repair_long_or_double_inputs(
+    parent_graph: torch.fx.GraphModule,
+    submodule: torch.fx.GraphModule,
+    submodule_inputs: Sequence[torch.Tensor],
+    submodule_name: Optional[str] = None,
+) -> Sequence[torch.Tensor]:
+    """Fixes all Long/Double type inputs to a TRT-accelerated subgraph
+
+    In-place modifies the provided graph
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        parent_graph: FX GraphModule enclosing the TRT subgraph
+        submodule: Child submodule to repair inputs on
+        submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        submodule_name: Optionally specify the name of the submodule target in the parent graph
+    Returns:
+        New submodule inputs, updated accordingly with long/double truncation
+    """
+    num_submodule_inputs = len(submodule_inputs)
+    repaired_outputs_once = False
+
+    # For each input to the TRT subgraph, check if its type is long/double
+    for position in range(num_submodule_inputs):
+        param = submodule_inputs[position]
+
+        # If the data type of the input is long/double, insert necessary
+        # casts to replace the operation
+        if param.dtype in (torch.int64, torch.float64):
+            # Ensure outputs are only repaired once per submodule to avoid
+            # unnecessary ops showing up in the graph
+            if not repaired_outputs_once:
+                submodule_outputs = submodule(*submodule_inputs)
+
+            _repair_64bit_input(
+                parent_graph,
+                position,
+                submodule_name if submodule_name is not None else submodule._get_name(),
+                None if repaired_outputs_once else submodule_outputs,
+                param.dtype,
+            )
+
+            repaired_outputs_once = True
+
+            # Repair submodule inputs in accordance with inserted casts
+            dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
+            submodule_inputs = (
+                submodule_inputs[:position]
+                + (param.to(dtype_32bit),)
+                + submodule_inputs[position + 1 :]
+            )
+
+    return submodule_inputs
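
For reference, the down-cast/up-cast pattern that `_repair_64bit_input` applies around a TRT submodule can be reproduced on a plain torch.fx graph without TensorRT. A minimal sketch (the toy module and the choice of `operator.mul` as the stand-in node are illustrative, not part of this commit):

import operator

import torch
import torch.fx


class ToyModule(torch.nn.Module):
    def forward(self, x):
        return x * 2


gm = torch.fx.symbolic_trace(ToyModule())

# The multiply node stands in for the TRT-accelerated submodule call
mul_node = next(
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target is operator.mul
)
node_64bit = mul_node.all_input_nodes[0]

# Downcast the 64-bit input to int32 just before the node and rewire it,
# mirroring the pre-submodule cast inserted by _repair_64bit_input
with gm.graph.inserting_before(mul_node):
    node_32bit = gm.graph.call_function(
        torch.ops.aten._to_copy.default,
        args=(node_64bit,),
        kwargs={"dtype": torch.int32},
    )
mul_node.replace_input_with(node_64bit, node_32bit)

# Upcast the result back to int64 after the node and redirect downstream users,
# mirroring the post-submodule cast
with gm.graph.inserting_after(mul_node):
    cast_node_64bit = gm.graph.call_function(
        torch.ops.aten._to_copy.default,
        args=(mul_node,),
        kwargs={"dtype": torch.int64},
    )
mul_node.replace_all_uses_with(
    cast_node_64bit, delete_user_cb=lambda user: user != cast_node_64bit
)

gm.graph.lint()
gm.recompile()

# The wrapped op now sees int32 internally while callers still receive int64
print(gm(torch.arange(4)).dtype)  # torch.int64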
