Commit f228eef
fix: Add truncate_long_and_double to Dynamo
- Add default, setting, and function arguments for `truncate_long_and_double` in Dynamo
- Add utilities for repairing long/double inputs to TRT engines, including support for autocasting back to long/double after the computation completes
- Add multiple helper functions to enable easy testing and diagnosis of long/double IO to TRT engines
- Add necessary compiler code to enable usage of the `truncate_long_and_double` argument as a switch for the feature
- Add Dynamo compile support for the `truncate_long_and_double` compilation argument by intercepting long/double type inputs and casting them to their 32-bit counterparts prior to usage in TRT-accelerated subgraphs, then casting back if necessary
- Add robust logic to handle 64-bit inputs and outputs
- Add test cases for long and double scenarios
- Centralize the truncation utility for later use in the Dynamo export path
1 parent 4d14ae8 commit f228eef
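
In effect, the pass wraps each TRT-accelerated subgraph in a pair of casts. A minimal stand-alone sketch of that behavior in plain PyTorch (the multiply is a stand-in for the TRT engine; nothing here imports the new utilities):

import torch

# Stand-in for what the truncation pass arranges around a TRT engine:
x = torch.randint(-5, 5, (16, 7), dtype=torch.int64)  # 64-bit model input

x_32 = x.to(torch.int32)   # cast inserted before the TRT-accelerated subgraph
y_32 = x_32 * 2            # the subgraph itself runs in 32-bit
y = y_32.to(torch.int64)   # upcast restored for downstream Torch operations

assert y.dtype == torch.int64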

File tree: 7 files changed (+337 −2 lines)

py/torch_tensorrt/dynamo/_defaults.py (+1)

@@ -9,3 +9,4 @@
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
 USE_PYTHON_RUNTIME = None
+TRUNCATE_LONG_AND_DOUBLE = False

py/torch_tensorrt/dynamo/_settings.py (+2)

@@ -11,6 +11,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )

@@ -26,3 +27,4 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
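
With the field in place, the setting can be toggled programmatically. A small sketch (assuming `CompilationSettings` is imported from the module above):

from torch_tensorrt.dynamo._settings import CompilationSettings

# Defaults flow in from _defaults.py; the new field can be overridden per-build
settings = CompilationSettings(truncate_long_and_double=True)
assert settings.truncate_long_and_double is True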

py/torch_tensorrt/dynamo/backend/backends.py (+10 −1)

@@ -16,7 +16,10 @@
     get_submod_inputs,
 )
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
-from torch_tensorrt.dynamo.conversion import convert_module
+from torch_tensorrt.dynamo.conversion import (
+    convert_module,
+    repair_long_or_double_inputs,
+)

 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler

@@ -135,6 +138,12 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )

+        # Handle long/double inputs if requested by the user
+        if settings.truncate_long_and_double:
+            submodule_inputs = repair_long_or_double_inputs(
+                partitioned_module, submodule, submodule_inputs, name
+            )
+
         # Create TRT Module from submodule
         trt_mod = convert_module(
             submodule,
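
Note the ordering: the repair runs before `convert_module`, so the converter only ever sees the downcast inputs. The dtype mapping the repair applies is fixed; a sketch of the rule (a local helper for illustration, not an import from the library):

import torch

def downcast_dtype(dtype: torch.dtype) -> torch.dtype:
    # Only the two 64-bit types are remapped; everything else passes through
    mapping = {torch.int64: torch.int32, torch.float64: torch.float32}
    return mapping.get(dtype, dtype)

assert downcast_dtype(torch.int64) is torch.int32
assert downcast_dtype(torch.float64) is torch.float32
assert downcast_dtype(torch.float16) is torch.float16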

py/torch_tensorrt/dynamo/compile.py (+3 −1)

@@ -30,6 +30,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )

@@ -53,7 +54,7 @@ def compile(
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
     calibrator=None,
-    truncate_long_and_double=False,
+    truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation=False,
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],

@@ -109,6 +110,7 @@ def compile(
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
+        "truncate_long_and_double": truncate_long_and_double,
     }

     settings = CompilationSettings(**compilation_options)
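
From the user's side, the flag is just another compile argument. A usage sketch mirroring the new tests below (`MyModule` is a placeholder for any module taking long/double inputs):

import torch
import torch_tensorrt

model = MyModule().eval().cuda()  # placeholder module with 64-bit inputs
inputs = [torch.randint(0, 10, (16, 7), dtype=torch.long).cuda()]

optimized = torch_tensorrt.compile(
    model,
    "torch_compile",
    inputs,
    min_block_size=1,
    truncate_long_and_double=True,
)
out = optimized(*inputs)  # outputs are upcast back to 64-bit where needed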
py/torch_tensorrt/dynamo/conversion/__init__.py (+1)

@@ -1,2 +1,3 @@
 from .trt_interpreter import *
 from .conversion import *
+from .truncate_long_and_double import repair_long_or_double_inputs
py/torch_tensorrt/dynamo/conversion/truncate_long_and_double.py (new file, +207)

import torch
from torch.fx.node import _get_qualified_name
from typing import Optional, Sequence, Union


def _extract_downstream_get_nodes(
    module_node: torch.fx.Node, output_indices: Sequence[int]
) -> Sequence[torch.fx.Node]:
    """Extracts downstream users of a node which get the item at a particular index

    Certain module-type nodes have multiple outputs (tuple of outputs). This function
    returns downstream nodes which call the _operator.getitem function, which extracts
    the element at a particular index in the tuple

    Args:
        module_node: FX module-type node to analyze
        output_indices: Indices in the module node output to search for
    Returns:
        List of nodes which get the item at the specified index in the module node output
    """
    get_nodes = []

    # Iterate over all downstream users of the node object
    for user in module_node.users:
        # If the user is a "get" node accessing the specified index, store it
        if _get_qualified_name(user.target) == "_operator.getitem" and (
            user.args[1] in output_indices
        ):
            get_nodes.append(user)

    return get_nodes


def _repair_64bit_input(
    gm: torch.fx.GraphModule,
    position: int,
    submodule_name: str,
    submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
    dtype: torch.dtype,
):
    """Fixes a single Long/Double input to a TRT-accelerated subgraph

    In-place modifies the provided graph

    Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
    inserts an upcast back to the 64-bit type for subsequent Torch operations

    Args:
        gm: FX GraphModule enclosing the TRT subgraph
        position: Index in the submodule inputs at which the long or double input is found
        submodule_name: Name of TRT-accelerated subgraph module in FX graph
        submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
        dtype: Data type of tensor at position in submodule (double/long)
    """
    assert dtype in (
        torch.int64,
        torch.float64,
    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"

    # Determine target data type in 32 and 64 bit forms
    dtype_64bit = dtype
    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32

    # Find the node representing the submodule in the graph
    module_node = None

    # Iterate over all nodes in the graph, seeking target module name match
    for n in gm.graph.nodes:
        if n.op == "call_module" and str(n.target) == submodule_name:
            module_node = n
            break

    if module_node is None:
        raise AssertionError(
            f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
        )

    # Extract the 64-bit node of the input
    node_64bit = module_node.all_input_nodes[position]

    # Prior to the module, insert a cast to the 32-bit equivalent node
    with gm.graph.inserting_before(module_node):
        node_32bit = gm.graph.call_function(
            torch.ops.aten._to_copy.default,
            args=(node_64bit,),
            kwargs={"dtype": dtype_32bit},
        )

    # Replace 64-bit input to TRT module with new 32-bit cast node
    module_node.replace_input_with(node_64bit, node_32bit)

    output_positions_64bit = set()
    outputs_list = (
        [submodule_outputs]
        if isinstance(submodule_outputs, torch.Tensor)
        else submodule_outputs
    )

    # Determine if any outputs of the model are 64-bit type and store their indices
    if submodule_outputs is not None:
        for output_position, output in enumerate(outputs_list):
            if output.dtype == dtype_64bit:
                output_positions_64bit.add(output_position)

    # Only enter this code block if there exists a 64-bit output
    # This implies a cast is needed, since TRT cannot output 64-bit tensors
    if output_positions_64bit:
        # Determine whether the outputs of the module are tuple-type or not
        is_collection_output = False
        if isinstance(submodule_outputs, tuple):
            is_collection_output = True

        if not is_collection_output:
            # If the output is a single tensor, insert a cast back to the 64-bit type
            with gm.graph.inserting_after(module_node):
                cast_node_64bit = gm.graph.call_function(
                    torch.ops.aten._to_copy.default,
                    args=(module_node,),
                    kwargs={"dtype": dtype_64bit},
                )

            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
            module_node.replace_all_uses_with(
                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
            )

        else:
            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
            get_nodes = _extract_downstream_get_nodes(
                module_node, output_positions_64bit
            )

            # For each downstream user, append a cast node back to the 64-bit precision
            for get_node in get_nodes:
                with gm.graph.inserting_after(get_node):
                    cast_node_64bit = gm.graph.call_function(
                        torch.ops.aten._to_copy.default,
                        args=(get_node,),
                        kwargs={"dtype": dtype_64bit},
                    )

                get_node.replace_all_uses_with(
                    cast_node_64bit,
                    delete_user_cb=lambda user: (user != cast_node_64bit),
                )

    # Clean up graph and ensure invariants are preserved
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()


def repair_long_or_double_inputs(
    parent_graph: torch.fx.GraphModule,
    submodule: torch.fx.GraphModule,
    submodule_inputs: Sequence[torch.Tensor],
    submodule_name: Optional[str] = None,
) -> Sequence[torch.Tensor]:
    """Fixes all Long/Double type inputs to a TRT-accelerated subgraph

    In-place modifies the provided graph

    Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
    inserts an upcast back to the 64-bit type for subsequent Torch operations

    Args:
        parent_graph: FX GraphModule enclosing the TRT subgraph
        submodule: Child submodule to repair inputs on
        submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
        submodule_name: Optionally specify the name of the submodule target in the parent graph
    Returns:
        New submodule inputs, updated accordingly with long/double truncation
    """
    num_submodule_inputs = len(submodule_inputs)
    repaired_outputs_once = False

    # For each input to the TRT subgraph, check if its type is long/double
    for position in range(num_submodule_inputs):
        param = submodule_inputs[position]

        # If the data type of the input is long/double, insert necessary
        # casts to replace the operation
        if param.dtype in (torch.int64, torch.float64):
            # Ensure outputs are only repaired once per submodule to avoid
            # unnecessary ops showing up in the graph
            if not repaired_outputs_once:
                submodule_outputs = submodule(*submodule_inputs)

            _repair_64bit_input(
                parent_graph,
                position,
                submodule_name if submodule_name is not None else submodule._get_name(),
                None if repaired_outputs_once else submodule_outputs,
                param.dtype,
            )

            repaired_outputs_once = True

            # Repair submodule inputs in accordance with inserted casts
            dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
            submodule_inputs = (
                submodule_inputs[:position]
                + (param.to(dtype_32bit),)
                + submodule_inputs[position + 1 :]
            )

    return submodule_inputs
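
The graph surgery above is ordinary FX rewriting. A self-contained sketch of the same insertion pattern on a toy graph (the function `f` below is illustrative; the real pass targets `call_module` nodes rather than a multiply):

import operator
import torch
import torch.fx

def f(x):
    return x * 2

gm = torch.fx.symbolic_trace(f)
mul = next(n for n in gm.graph.nodes if n.target == operator.mul)
x_in = mul.all_input_nodes[0]

# Insert a downcast before the node, mirroring _repair_64bit_input
with gm.graph.inserting_before(mul):
    x32 = gm.graph.call_function(
        torch.ops.aten._to_copy.default, args=(x_in,), kwargs={"dtype": torch.int32}
    )
mul.replace_input_with(x_in, x32)

# Insert an upcast after the node and reroute all other users to it
with gm.graph.inserting_after(mul):
    back64 = gm.graph.call_function(
        torch.ops.aten._to_copy.default, args=(mul,), kwargs={"dtype": torch.int64}
    )
mul.replace_all_uses_with(back64, delete_user_cb=lambda user: user != back64)

gm.graph.lint()
gm.recompile()

out = gm(torch.ones(3, dtype=torch.int64))
assert out.dtype == torch.int64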

tests/py/dynamo/backend/test_backend_compiler.py (+113)

@@ -171,5 +171,118 @@ def forward(self, x, y):
         )


+class Test64BitInput(TestCase):
+    def test_float64_input_full_support(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.mean.dim(
+                    torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
+                )
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+        inputs = [
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+        ]
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "TRT outputs don't match with the original model.",
+        )
+
+    def test_int64_input_partial_support(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.div.Tensor_mode(
+                    x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
+                )
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
+        ]
+
+        (unexpected_ops_seen, _, partitioned_graphs) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+        self.assertEquals(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEquals(
+            len(list(partitioned_graphs[0].named_children())),
+            1,
+            "Certain operators are set to run in Torch, expected 1 segment",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "TRT outputs don't match with the original model.",
+        )
+
+
 if __name__ == "__main__":
     run_tests()
