Skip to content

Commit 7d658d6

Browse files
committed
fix: Reorganize Dynamo backends
- Rename key backends to establish default backend and optional alternatives
- Update function headers and docstrings, as well as key imports
- Rename `torch_compile` folder to `backends` in accordance with `torch.compile` terminology
- Update references throughout codebase
- Specify certain functions as private/helper via underscore
1 parent d4e5ed0 commit 7d658d6

17 files changed

+77
-39
lines changed

py/torch_tensorrt/_compile.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def compile(
157157
**kwargs,
158158
)
159159
elif target_ir == _IRType.torch_compile:
160-
return torch_tensorrt.dynamo.torch_compile(
160+
return torch_tensorrt.dynamo.compile(
161161
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
162162
)
163163
elif target_ir == _IRType.fx_ts_compat:

py/torch_tensorrt/dynamo/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from torch_tensorrt.dynamo import fx_ts_compat
2-
from .torch_compile import compile as torch_compile
2+
from .backend import compile

py/torch_tensorrt/dynamo/torch_compile/__init__.py renamed to py/torch_tensorrt/dynamo/backend/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from torch_tensorrt import EngineCapability, Device
99
from torch_tensorrt.fx.utils import LowerPrecision
1010

11-
from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
12-
from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device
13-
from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend
14-
from torch_tensorrt.dynamo.torch_compile._defaults import (
11+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
12+
from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
13+
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
14+
from torch_tensorrt.dynamo.backend._defaults import (
1515
PRECISION,
1616
DEBUG,
1717
MAX_WORKSPACE_SIZE,
@@ -121,6 +121,6 @@ def create_backend(
121121
)
122122

123123
return partial(
124-
tensorrt_backend,
124+
torch_tensorrt_backend,
125125
settings=settings,
126126
)

py/torch_tensorrt/dynamo/torch_compile/_settings.py renamed to py/torch_tensorrt/dynamo/backend/_settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22

33
from torch_tensorrt.fx.utils import LowerPrecision
4-
from torch_tensorrt.dynamo.torch_compile._defaults import (
4+
from torch_tensorrt.dynamo.backend._defaults import (
55
PRECISION,
66
DEBUG,
77
MAX_WORKSPACE_SIZE,

py/torch_tensorrt/dynamo/torch_compile/backends.py renamed to py/torch_tensorrt/dynamo/backend/backends.py

+29-19
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,42 @@
44
from functools import partial
55
import torch._dynamo as td
66

7-
from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
8-
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
7+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
8+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
99
get_decompositions,
1010
)
11-
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
11+
from torch_tensorrt.dynamo.backend.lowering._partition import (
1212
partition,
1313
get_submod_inputs,
1414
)
15-
from torch_tensorrt.dynamo.torch_compile.conversion import convert_module
15+
from torch_tensorrt.dynamo.backend.conversion import convert_module
1616

1717
from torch._dynamo.backends.common import fake_tensor_unsupported
1818

1919
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
2020

2121

22-
@td.register_backend(name="tensorrt")
22+
@td.register_backend(name="torch_tensorrt")
2323
@fake_tensor_unsupported
24-
def tensorrt_backend(
25-
gm: torch.nn.Module,
24+
def torch_tensorrt_backend(
25+
gm: torch.fx.GraphModule,
26+
sample_inputs: Sequence[torch.Tensor],
27+
settings: CompilationSettings = CompilationSettings(),
28+
):
29+
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
30+
31+
return DEFAULT_BACKEND(gm=gm, sample_inputs=sample_inputs, settings=settings)
32+
33+
34+
@td.register_backend(name="aot_torch_tensorrt_aten")
35+
@fake_tensor_unsupported
36+
def aot_torch_tensorrt_aten_backend(
37+
gm: torch.fx.GraphModule,
2638
sample_inputs: Sequence[torch.Tensor],
2739
settings: CompilationSettings = CompilationSettings(),
2840
):
2941
custom_backend = partial(
30-
fx_dynamo_backend,
42+
_pretraced_backend,
3143
settings=settings,
3244
)
3345

@@ -40,14 +52,12 @@ def tensorrt_backend(
4052
)
4153

4254

43-
@td.register_backend(name="fx_tensorrt")
44-
@fake_tensor_unsupported
45-
def fx_dynamo_backend(
55+
def _pretraced_backend(
4656
gm: torch.fx.GraphModule,
47-
example_inputs: Sequence[torch.Tensor],
57+
sample_inputs: Sequence[torch.Tensor],
4858
settings: CompilationSettings = CompilationSettings(),
4959
):
50-
"""Helper function to manage translation of FX module to TRT engines
60+
"""Helper function to manage translation of traced FX module to TRT engines
5161
5262
Args:
5363
module: FX GraphModule to convert
@@ -57,9 +67,9 @@ def fx_dynamo_backend(
5767
Compiled FX GraphModule
5868
"""
5969
try:
60-
trt_compiled = compile_module(
70+
trt_compiled = _compile_module(
6171
gm,
62-
example_inputs,
72+
sample_inputs,
6373
settings=settings,
6474
)
6575
return trt_compiled
@@ -72,12 +82,12 @@ def fx_dynamo_backend(
7282
return gm.forward
7383

7484

75-
def compile_module(
85+
def _compile_module(
7686
gm: torch.fx.GraphModule,
77-
example_inputs: Sequence[torch.Tensor],
87+
sample_inputs: Sequence[torch.Tensor],
7888
settings: CompilationSettings = CompilationSettings(),
7989
) -> torch.fx.GraphModule:
80-
"""Compile an FX module
90+
"""Compile a traced FX module
8191
8292
Includes: Partitioning + Conversion Phases
8393
@@ -100,7 +110,7 @@ def compile_module(
100110

101111
# Get submodule inputs
102112
submodule_inputs = get_submod_inputs(
103-
partitioned_module, submodule, example_inputs
113+
partitioned_module, submodule, sample_inputs
104114
)
105115

106116
# Create TRT Module from submodule
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
2+
get_decompositions,
3+
)
4+
from torch_tensorrt.dynamo.backend.lowering._partition import (
5+
partition,
6+
get_submod_inputs,
7+
)

py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py renamed to py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES
5+
from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES
66
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
77
from torch.fx.passes.operator_support import OperatorSupport
88

py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py renamed to py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
1+
from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs
22
from utils import same_output_format
33
import torch_tensorrt
44
import unittest

py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py renamed to py/torch_tensorrt/dynamo/backend/test/test_partitioning.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from torch_tensorrt.dynamo.torch_compile.lowering import partition
1+
from torch_tensorrt.dynamo.backend.lowering import partition
22
from torch.testing._internal.common_utils import run_tests, TestCase
33
import torch
44
from copy import deepcopy
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from torch.testing._internal.common_utils import run_tests, TestCase
2+
import torch
3+
import torch_tensorrt
4+
import torchvision.models as models
5+
from torch_tensorrt.dynamo.test.utils import COSINE_THRESHOLD, cosine_similarity
6+
7+
8+
class TestResNet18(TestCase):
9+
def test_resnet18(ir):
10+
model = models.resnet18(pretrained=True).eval().to("cuda")
11+
input_ = torch.randn((1, 3, 224, 224)).to("cuda")
12+
13+
compile_spec = {
14+
"inputs": [input_],
15+
"enabled_precisions": {torch.float},
16+
"pass_through_build_failures": True,
17+
}
18+
19+
trt_mod = torch_tensorrt.dynamo.backend(model, **compile_spec)
20+
cos_sim = cosine_similarity(model(input_), trt_mod(input_))
21+
assert (
22+
cos_sim > COSINE_THRESHOLD,
23+
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
24+
)
25+
26+
27+
if __name__ == "__main__":
28+
run_tests()

py/torch_tensorrt/dynamo/torch_compile/test/utils.py renamed to py/torch_tensorrt/dynamo/backend/test/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from functools import partial
33
from typing import List, Sequence
44
import torch
5-
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
5+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
66
get_decompositions,
77
)
8-
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
8+
from torch_tensorrt.dynamo.backend.lowering._partition import (
99
partition,
1010
)
1111

py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py

-7
This file was deleted.

0 commit comments

Comments
 (0)