Commit 57215f1

feat: Prototype Module-Acceleration in Dynamo
- Add support for excluding entire Torch modules from tracing in Dynamo using Torch custom operators
- Develop a new dataclass to store the required replacement functions and operators in a streamlined way
- Add a new registry storing the mapping between replacement operators and their corresponding dataclass
- Add documentation for easily adding new module-level exclusion operators
- Add robust testing and address recent review comments
1 parent e9ec251 commit 57215f1

9 files changed (+338 -7 lines)
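Taken together, these files register per-module substitutions in a central registry and swap matching call_module nodes for custom-operator calls before AOT tracing. The sketch below is illustrative only and is not a file in this commit; it assumes the identifiers introduced in the diffs that follow (pre_aot_module_replacement, MODULE_SUBSTITUTION_REGISTRY, and the torch.ops.tensorrt.maxpool1d operator registered in maxpool1d.py).

# Illustrative sketch only -- not part of this commit.
import torch
from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_module_replacement,
)


class Pool(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.MaxPool1d(2)

    def forward(self, x):
        return self.pool(x)


# Importing the lowering package runs `from .module_substitutions import *`,
# which registers the MaxPool1d substitution added later in this diff.
assert torch.nn.MaxPool1d in MODULE_SUBSTITUTION_REGISTRY

# Trace the module, then replace the MaxPool1d call_module node with a
# call_function node targeting the custom torch.ops.tensorrt.maxpool1d operator.
gm = torch.fx.symbolic_trace(Pool())
gm = pre_aot_module_replacement(gm)
print(gm.graph)  # the graph now calls torch.ops.tensorrt.maxpool1d directly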

py/torch_tensorrt/__init__.py

+1 -1

@@ -94,7 +94,7 @@ def _find_lib(name, paths):
 
 from torch_tensorrt import fx
 
-if version.parse(torch.__version__) >= version.parse("2.dev"):
+if version.parse(torch.__version__) >= version.parse("2.1.dev"):
     from torch_tensorrt import dynamo
     from torch_tensorrt.dynamo import backend

py/torch_tensorrt/dynamo/backend/backends.py

+12
@@ -8,6 +8,9 @@
 from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
 )
+from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
+    pre_aot_module_replacement,
+)
 from torch_tensorrt.dynamo.backend.lowering._partition import (
     partition,
     get_submod_inputs,
@@ -45,6 +48,13 @@ def aot_torch_tensorrt_aten_backend(
         settings=settings,
     )
 
+    logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
+
+    # Perform Pre-AOT Lowering for Module-Level Replacement
+    gm = pre_aot_module_replacement(gm)
+
+    logger.debug("Post-module replacement graph:\n" + str(gm.graph))
+
     # Invoke AOTAutograd to translate operators to aten
     return aot_module_simplified(
         gm,
@@ -70,6 +80,8 @@ def _pretraced_backend(
         Compiled FX GraphModule
     """
     try:
+        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+
         trt_compiled = _compile_module(
             gm,
             sample_inputs,

py/torch_tensorrt/dynamo/backend/lowering/__init__.py

+6 -4

@@ -1,7 +1,9 @@
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from ._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.backend.lowering._partition import (
-    partition,
-    get_submod_inputs,
+from ._pre_aot_lowering import (
+    MODULE_SUBSTITUTION_REGISTRY,
+    module_substitution,
 )
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
+from .module_substitutions import *

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+10 -2

@@ -1,9 +1,10 @@
 import logging
-from typing import Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence, Set
 
 import torch
 
 from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
@@ -14,6 +15,11 @@
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+    _get_qualified_name(module.new_operator)
+    for module in MODULE_SUBSTITUTION_REGISTRY.values()
+)
+
 
 class TRTPartitioner(CapabilityBasedPartitioner):
     """Partitioner to split an FX graph into subgraphs based on operator support
@@ -35,7 +41,9 @@ def __init__(
         operator_support: OperatorSupport,
         *,
         non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
         min_block_size=MIN_BLOCK_SIZE,
     ) -> None:
         super().__init__(
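
As a usage note (not part of the diff), the new DEFAULT_SINGLE_NODE_PARTITIONS default appears intended to let operators produced by registered module substitutions stand alone as single-node TRT partitions, even when min_block_size would otherwise filter them out. A minimal sketch of how the set is derived from the registry, mirroring the definition above:

# Illustrative sketch only -- mirrors the DEFAULT_SINGLE_NODE_PARTITIONS definition above.
from torch.fx.node import _get_qualified_name

from torch_tensorrt.dynamo.backend.lowering import MODULE_SUBSTITUTION_REGISTRY

single_node_ops = {
    _get_qualified_name(replacement.new_operator)
    for replacement in MODULE_SUBSTITUTION_REGISTRY.values()
}
# With only the MaxPool1d substitution registered, this holds the qualified name of
# torch.ops.tensorrt.maxpool1d.
print(single_node_ops)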

py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py

+121 (new file)

from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
import torch
import logging


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ModuleReplacement:
    """Class to store key functionality for module replacement"""

    # torch.ops.___ name for replacement function for module
    new_operator: torch._ops.OpOverload

    # Function taking a containing graph, a submodule, and a 'call_module' node and returning
    # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
    subgraph_insertion_fn: Callable[
        [torch.fx.GraphModule, torch.nn.Module, torch.fx.Node], torch.fx.Node
    ]


# Dictionary mapping module to ModuleReplacement instance
MODULE_SUBSTITUTION_REGISTRY: Dict[Type[torch.nn.Module], ModuleReplacement] = dict()


def module_substitution(
    module_to_replace: Type[torch.nn.Module],
    new_operator: torch._ops.OpOverload,
    enabled: bool = True,
) -> Callable[[Any], Any]:
    """Decorator to register subgraph insertion functions

    Args:
        module_to_replace: nn.Module to replace
        new_operator: Custom torch operator to replace with
        enabled: Whether the substitution is enabled or disabled
    Returns:
        torch.fx.GraphModule
    """

    def register_substitution(subgraph_insertion_fn):
        """Function for use if substitution is enabled"""
        module_replacement = ModuleReplacement(
            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
        )
        MODULE_SUBSTITUTION_REGISTRY[module_to_replace] = module_replacement
        return subgraph_insertion_fn

    def disable_substitution(subgraph_insertion_fn):
        """Function for use if substitution is disabled"""
        return subgraph_insertion_fn

    return register_substitution if enabled else disable_substitution


def pre_aot_module_replacement(gm: torch.fx.GraphModule):
    """Perform module-level graph replacement prior to AOT tracing

    Args:
        gm: FX GraphModule to perform module replacement on
    Returns:
        torch.fx.GraphModule

    """
    # Ensure all parameters are in inference mode
    for param in gm.parameters():
        param.requires_grad = False

    # Iterate over graph nodes, extracting module calls, to check for interceptions
    for n in gm.graph.nodes:
        if n.op == "call_module":
            # Extract submodule from graph
            submodule = gm.get_submodule(n.target)

            # If submodule is a member of the substitution registry, replace it
            if type(submodule) in MODULE_SUBSTITUTION_REGISTRY:

                try:
                    replacement = MODULE_SUBSTITUTION_REGISTRY[type(submodule)]
                    op, insertion_fn = (
                        replacement.new_operator,
                        replacement.subgraph_insertion_fn,
                    )
                    logger.debug(
                        f"Replacing module of type {type(submodule)} with {op}"
                    )

                    # Insert new node prior to older node
                    with gm.graph.inserting_before(n):
                        new_node = insertion_fn(gm, submodule, n)

                    # If submodule is not a native torch.nn module, it must be manually excluded
                    # from Dynamo tracing
                    if not type(submodule).__module__.startswith("torch.nn"):
                        torch._dynamo.allowed_functions._allowed_function_ids.add(
                            id(type(submodule))
                        )

                    # Replace all original node uses and clean up graph
                    n.replace_all_uses_with(new_node)
                    gm.graph.eliminate_dead_code()
                    gm.graph.lint()
                    gm.recompile()

                # A module replacement can fail in the event that the specific instance of the submodule cannot
                # be replaced
                except Exception:
                    logger.debug(
                        f"Encountered error while replacing {type(submodule)}",
                        exc_info=True,
                    )
                    continue

    # Perform cleanup and recompilation before returning module
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
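
The enabled flag on module_substitution is not exercised elsewhere in this commit; the following hypothetical sketch (not part of the diff) shows how a substitution could be switched off so that a module keeps its default tracing path:

# Hypothetical sketch only -- demonstrates the `enabled` flag; not part of this commit.
import torch
from torch_tensorrt.dynamo.backend.lowering import module_substitution


# With enabled=False the decorator is a no-op: it returns the function without adding
# an entry to MODULE_SUBSTITUTION_REGISTRY (existing registrations are untouched).
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d, enabled=False)
def disabled_maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    ...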

py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/__init__.py

+1 (new file)

from .maxpool1d import *

py/torch_tensorrt/dynamo/backend/lowering/module_substitutions/maxpool1d.py

+127 (new file)

from typing import Dict, Tuple
import torch
from torch._custom_op.impl import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import module_substitution


# This file serves as an example and a tutorial for excluding custom modules from
# torch.compile tracing. Each required step is labeled with a number indicating the
# preferable implementation order.


# 1. The Placeholder
#
# Specify the schema and namespace of the operator, as well as a placeholder function
# representing the schema. The schema should be in torch JIT syntax, indicating input and output
# types. The namespace, such as tensorrt, will cause the op to be registered as torch.ops.tensorrt.your_op
# Then, create a placeholder function with no operations, but having the same schema and naming as that
# used in the decorator
@custom_op(
    qualname="tensorrt::maxpool1d",
    manual_schema="(Tensor x, int[1] kernel_size, int[1] stride, int[1] padding, int[1] dilation, bool ceil_mode) -> Tensor",
)
def maxpool1d(x, kernel_size, stride, padding, dilation, ceil_mode):
    # Defines operator schema, name, namespace, and function header
    ...


# 2. The Generic Implementation
#
# Define the default implementation of the operator in torch syntax. This is used for autograd
# and other tracing functionality. Generally, the torch.nn.functional analog of the operator to replace
# is desirable. If the operator to replace is a custom module you've written, then add its Torch
# implementation here. Note that the function header to the generic function can have specific arguments
# as in the above placeholder
@maxpool1d.impl("cpu")
@maxpool1d.impl("cuda")
@maxpool1d.impl_abstract()
def maxpool1d_generic(
    *args,
    **kwargs,
):
    # Defines an implementation for AOT Autograd to use for shape analysis/propagation
    return torch.nn.functional.max_pool1d(
        *args,
        **kwargs,
    )


# 3. The Module Substitution Function
#
# Define a function which can intercept a node of the kind to be replaced, extract
# the relevant data from that node/submodule, and then re-package the information
# for use by an accelerated implementation (to be implemented in step 4). This function
# should use the operator defined in step 1 (for example torch.ops.tensorrt.maxpool1d).
# It should refactor the args and kwargs as is needed by the accelerated implementation.
#
# If the submodule has weights or other Tensor fields which the accelerated implementation
# needs, the function should insert the necessary nodes to access those weights. For example,
# if the weight Tensor of a submodule is needed, one could write:
#
#     weights = gm.graph.get_attr(n.target + ".weight", torch.Tensor)
#     bias = gm.graph.get_attr(n.target + ".bias", torch.Tensor)
#     ...
#     kwargs={"weight": weights,
#             "bias": bias,
#             ...
#
@module_substitution(torch.nn.MaxPool1d, torch.ops.tensorrt.maxpool1d)
def maxpool1d_insertion_fn(
    gm: torch.fx.GraphModule, submodule: torch.nn.Module, node: torch.fx.Node
) -> torch.fx.Node:
    # Defines insertion function for new node
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.maxpool1d,
        args=node.args,
        kwargs={
            "kernel_size": submodule.kernel_size,
            "stride": submodule.stride,
            "padding": submodule.padding,
            "dilation": submodule.dilation,
            "ceil_mode": submodule.ceil_mode,
        },
    )

    return new_node


# 4. The Accelerated Implementation
#
# Define an accelerated implementation of the operator, and register it as necessary.
# This accelerated implementation should consume the args/kwargs specified in step 3.
# One should expect that torch.compile will compress all kwargs into the args field in
# the order specified in the schema written in step 1.
@tensorrt_converter(torch.ops.tensorrt.maxpool1d.default)
def tensorrt_maxpool1d(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines converter replacing the default operator for this function
    kwargs_new = {
        "input": args[0],
        "kernel_size": args[1],
        "stride": args[2],
        "padding": args[3],
        "dilation": args[4],
        "ceil_mode": False if len(args) < 6 else args[5],
    }

    return acc_ops_converters.acc_ops_max_pool1d(
        network, target, None, kwargs_new, name
    )


# 5. Add Imports
#
# Add your accelerated module file to the __init__.py in this directory, to ensure
# all registrations are run. For instance, if the new module file is called new_mod.py,
# one should add `from .new_mod import *` to the __init__.py
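
As a quick sanity check (not part of the commit), the generic implementation from step 2 gives the custom operator real eager behavior, so it can be called directly and compared against its torch.nn.functional analog:

# Illustrative sketch only -- exercises the generic implementation registered in step 2.
import torch
import torch_tensorrt.dynamo.backend.lowering  # noqa: F401  (runs the op/converter registrations)

x = torch.rand(9, 16, 2)
# Positional args follow the schema from step 1:
# (x, kernel_size, stride, padding, dilation, ceil_mode)
out = torch.ops.tensorrt.maxpool1d(x, [2], [2], [0], [1], False)
ref = torch.nn.functional.max_pool1d(x, kernel_size=2)
assert torch.equal(out, ref)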

New test file (MaxPool1d module substitution)

+55 (new file)

import torch
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.dynamo import compile


class TestMaxPool1D(TestCase):
    def test_pre_aot_lowering_maxpool1d(self):
        class MaxPool1D(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)
                self.maxpool = torch.nn.MaxPool1d(2)

            def forward(self, x):
                return self.maxpool(x)

        # Operations expected to be included in the traced graph after decompositions
        expected_ops = {torch.ops.tensorrt.maxpool1d.default}

        inputs = [
            torch.rand(
                9,
                16,
                2,
            ).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(MaxPool1D())
        _, expected_ops_unseen = lower_graph_testing(
            fx_graph, inputs, expected_ops=expected_ops, min_block_size=1
        )

        self.assertEquals(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = compile(
            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = torch.max(torch.abs(optimized_model_results - torch_model_results))
        self.assertAlmostEqual(
            max_diff, 0, f"Maxpool1d TRT outputs don't match with the original model."
        )


if __name__ == "__main__":
    run_tests()
