Skip to content

Commit 2eaf7ca

Browse files
committed
fix: Reorganize Dynamo backends
- Rename key backends to establish default backend and optional alternatives - Update function headers and docstrings, as well as key imports - Rename `torch_compile` folder to `backends` in accordance with `torch.compile` terminology - Update references throughout codebase - Specify certain functions as private/helper via underscore
1 parent d4e5ed0 commit 2eaf7ca

20 files changed

+96
-58
lines changed

.circleci/config.yml

+11-11
Original file line numberDiff line numberDiff line change
@@ -763,33 +763,33 @@ commands:
763763
- store_artifacts:
764764
path: /tmp/testlogs
765765

766-
test-dynamo-torch_compile-core:
767-
description: "Test the Dynamo torch_compile path"
766+
test-dynamo-compile-core:
767+
description: "Test the Dynamo compile path"
768768
steps:
769769
- run:
770-
name: Run Dynamo torch_compile core tests
770+
name: Run Dynamo compile core tests
771771
command: |
772-
cd py/torch_tensorrt/dynamo/torch_compile
772+
cd py/torch_tensorrt/dynamo/backend
773773
pushd test/
774-
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
774+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml
775775
popd
776776
777777
- store_test_results:
778778
path: /tmp/artifacts
779779
- store_artifacts:
780780
path: /tmp/testlogs
781781

782-
test-dynamo-torch_compile:
783-
description: "Test the Dynamo torch_compile path"
782+
test-dynamo-compile:
783+
description: "Test the Dynamo compile path"
784784
steps:
785785
- run:
786-
name: Run Dynamo torch_compile E2E tests
786+
name: Run Dynamo compile E2E tests
787787
command: |
788788
cd py/torch_tensorrt/dynamo/
789789
pushd test/
790790
pip3 install timm
791791
pip3 install transformers
792-
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml --ir torch_compile
792+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile
793793
popd
794794
795795
- store_test_results:
@@ -1051,8 +1051,8 @@ jobs:
10511051
command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
10521052
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
10531053
- dump-test-env
1054-
- test-dynamo-torch_compile
1055-
- test-dynamo-torch_compile-core
1054+
- test-dynamo-compile
1055+
- test-dynamo-compile-core
10561056
- test-dynamo-fx_ts
10571057

10581058
package-x86_64-linux:

py/torch_tensorrt/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _find_lib(name, paths):
9797

9898
if version.parse(torch.__version__) >= version.parse("2.dev"):
9999
from torch_tensorrt import dynamo
100-
from torch_tensorrt.dynamo import torch_compile
100+
from torch_tensorrt.dynamo import backend
101101

102102

103103
def _register_with_torch():

py/torch_tensorrt/_compile.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class _IRType(Enum):
1616
ts = 0
1717
fx = 1
1818
fx_ts_compat = 2
19-
torch_compile = 3
19+
dynamo_compile = 3
2020

2121

2222
class _ModuleType(Enum):
@@ -47,7 +47,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
4747

4848
ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
4949
ir_targets_fx = ir == "fx"
50-
ir_targets_torch_compile = ir == "torch_compile"
50+
ir_targets_dynamo_compile = ir == "dynamo_compile"
5151
ir_targets_fx_ts_compat = ir == "fx_ts_compat"
5252

5353
if module_is_tsable and ir_targets_torchscript:
@@ -56,8 +56,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
5656
return _IRType.fx
5757
elif module_is_fxable and ir_targets_fx_ts_compat:
5858
return _IRType.fx_ts_compat
59-
elif module_is_fxable and ir_targets_torch_compile:
60-
return _IRType.torch_compile
59+
elif module_is_fxable and ir_targets_dynamo_compile:
60+
return _IRType.dynamo_compile
6161
else:
6262
if ir == "default":
6363
# Options are listed in order of preference
@@ -156,8 +156,8 @@ def compile(
156156
dynamic_batch=False,
157157
**kwargs,
158158
)
159-
elif target_ir == _IRType.torch_compile:
160-
return torch_tensorrt.dynamo.torch_compile(
159+
elif target_ir == _IRType.dynamo_compile:
160+
return torch_tensorrt.dynamo.compile(
161161
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
162162
)
163163
elif target_ir == _IRType.fx_ts_compat:

py/torch_tensorrt/dynamo/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from torch_tensorrt.dynamo import fx_ts_compat
2-
from .torch_compile import compile as torch_compile
2+
from .backend import compile

py/torch_tensorrt/dynamo/torch_compile/__init__.py renamed to py/torch_tensorrt/dynamo/backend/__init__.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from torch_tensorrt import EngineCapability, Device
99
from torch_tensorrt.fx.utils import LowerPrecision
1010

11-
from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
12-
from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device
13-
from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend
14-
from torch_tensorrt.dynamo.torch_compile._defaults import (
11+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
12+
from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
13+
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
14+
from torch_tensorrt.dynamo.backend._defaults import (
1515
PRECISION,
1616
DEBUG,
1717
MAX_WORKSPACE_SIZE,
@@ -121,6 +121,6 @@ def create_backend(
121121
)
122122

123123
return partial(
124-
tensorrt_backend,
124+
torch_tensorrt_backend,
125125
settings=settings,
126126
)

py/torch_tensorrt/dynamo/torch_compile/_settings.py renamed to py/torch_tensorrt/dynamo/backend/_settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22

33
from torch_tensorrt.fx.utils import LowerPrecision
4-
from torch_tensorrt.dynamo.torch_compile._defaults import (
4+
from torch_tensorrt.dynamo.backend._defaults import (
55
PRECISION,
66
DEBUG,
77
MAX_WORKSPACE_SIZE,

py/torch_tensorrt/dynamo/torch_compile/backends.py renamed to py/torch_tensorrt/dynamo/backend/backends.py

+29-19
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,42 @@
44
from functools import partial
55
import torch._dynamo as td
66

7-
from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
8-
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
7+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
8+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
99
get_decompositions,
1010
)
11-
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
11+
from torch_tensorrt.dynamo.backend.lowering._partition import (
1212
partition,
1313
get_submod_inputs,
1414
)
15-
from torch_tensorrt.dynamo.torch_compile.conversion import convert_module
15+
from torch_tensorrt.dynamo.backend.conversion import convert_module
1616

1717
from torch._dynamo.backends.common import fake_tensor_unsupported
1818

1919
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
2020

2121

22-
@td.register_backend(name="tensorrt")
22+
@td.register_backend(name="torch_tensorrt")
2323
@fake_tensor_unsupported
24-
def tensorrt_backend(
25-
gm: torch.nn.Module,
24+
def torch_tensorrt_backend(
25+
gm: torch.fx.GraphModule,
26+
sample_inputs: Sequence[torch.Tensor],
27+
settings: CompilationSettings = CompilationSettings(),
28+
):
29+
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
30+
31+
return DEFAULT_BACKEND(gm=gm, sample_inputs=sample_inputs, settings=settings)
32+
33+
34+
@td.register_backend(name="aot_torch_tensorrt_aten")
35+
@fake_tensor_unsupported
36+
def aot_torch_tensorrt_aten_backend(
37+
gm: torch.fx.GraphModule,
2638
sample_inputs: Sequence[torch.Tensor],
2739
settings: CompilationSettings = CompilationSettings(),
2840
):
2941
custom_backend = partial(
30-
fx_dynamo_backend,
42+
_pretraced_backend,
3143
settings=settings,
3244
)
3345

@@ -40,14 +52,12 @@ def tensorrt_backend(
4052
)
4153

4254

43-
@td.register_backend(name="fx_tensorrt")
44-
@fake_tensor_unsupported
45-
def fx_dynamo_backend(
55+
def _pretraced_backend(
4656
gm: torch.fx.GraphModule,
47-
example_inputs: Sequence[torch.Tensor],
57+
sample_inputs: Sequence[torch.Tensor],
4858
settings: CompilationSettings = CompilationSettings(),
4959
):
50-
"""Helper function to manage translation of FX module to TRT engines
60+
"""Helper function to manage translation of traced FX module to TRT engines
5161
5262
Args:
5363
module: FX GraphModule to convert
@@ -57,9 +67,9 @@ def fx_dynamo_backend(
5767
Compiled FX GraphModule
5868
"""
5969
try:
60-
trt_compiled = compile_module(
70+
trt_compiled = _compile_module(
6171
gm,
62-
example_inputs,
72+
sample_inputs,
6373
settings=settings,
6474
)
6575
return trt_compiled
@@ -72,12 +82,12 @@ def fx_dynamo_backend(
7282
return gm.forward
7383

7484

75-
def compile_module(
85+
def _compile_module(
7686
gm: torch.fx.GraphModule,
77-
example_inputs: Sequence[torch.Tensor],
87+
sample_inputs: Sequence[torch.Tensor],
7888
settings: CompilationSettings = CompilationSettings(),
7989
) -> torch.fx.GraphModule:
80-
"""Compile an FX module
90+
"""Compile a traced FX module
8191
8292
Includes: Partitioning + Conversion Phases
8393
@@ -100,7 +110,7 @@ def compile_module(
100110

101111
# Get submodule inputs
102112
submodule_inputs = get_submod_inputs(
103-
partitioned_module, submodule, example_inputs
113+
partitioned_module, submodule, sample_inputs
104114
)
105115

106116
# Create TRT Module from submodule
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
2+
get_decompositions,
3+
)
4+
from torch_tensorrt.dynamo.backend.lowering._partition import (
5+
partition,
6+
get_submod_inputs,
7+
)

py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py renamed to py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES
5+
from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES
66
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
77
from torch.fx.passes.operator_support import OperatorSupport
88

py/torch_tensorrt/dynamo/torch_compile/test/test_compiler_utils.py renamed to py/torch_tensorrt/dynamo/backend/test/test_compiler_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
1+
from torch_tensorrt.dynamo.backend.utils import prepare_device, prepare_inputs
22
from utils import same_output_format
33
import torch_tensorrt
44
import unittest

py/torch_tensorrt/dynamo/torch_compile/test/test_partitioning.py renamed to py/torch_tensorrt/dynamo/backend/test/test_partitioning.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from torch_tensorrt.dynamo.torch_compile.lowering import partition
1+
from torch_tensorrt.dynamo.backend.lowering import partition
22
from torch.testing._internal.common_utils import run_tests, TestCase
33
import torch
44
from copy import deepcopy
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from torch.testing._internal.common_utils import run_tests, TestCase
2+
import torch
3+
import torch_tensorrt
4+
import torchvision.models as models
5+
from torch_tensorrt.dynamo.test.utils import COSINE_THRESHOLD, cosine_similarity
6+
7+
8+
class TestResNet18(TestCase):
9+
def test_resnet18(ir):
10+
model = models.resnet18(pretrained=True).eval().to("cuda")
11+
input_ = torch.randn((1, 3, 224, 224)).to("cuda")
12+
13+
compile_spec = {
14+
"inputs": [input_],
15+
"enabled_precisions": {torch.float},
16+
"pass_through_build_failures": True,
17+
}
18+
19+
trt_mod = torch_tensorrt.dynamo.backend(model, **compile_spec)
20+
cos_sim = cosine_similarity(model(input_), trt_mod(input_))
21+
assert (
22+
cos_sim > COSINE_THRESHOLD,
23+
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
24+
)
25+
26+
27+
if __name__ == "__main__":
28+
run_tests()

py/torch_tensorrt/dynamo/torch_compile/test/utils.py renamed to py/torch_tensorrt/dynamo/backend/test/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from functools import partial
33
from typing import List, Sequence
44
import torch
5-
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
5+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
66
get_decompositions,
77
)
8-
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
8+
from torch_tensorrt.dynamo.backend.lowering._partition import (
99
partition,
1010
)
1111

py/torch_tensorrt/dynamo/torch_compile/utils.py renamed to py/torch_tensorrt/dynamo/backend/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def prepare_inputs(
4545

4646
else:
4747
raise ValueError(
48-
f"Invalid input type {type(inputs)} encountered in the torch_compile input parsing. "
48+
f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. "
4949
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
5050
)
5151

py/torch_tensorrt/dynamo/test/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def pytest_addoption(parser):
99
type=str,
1010
required=True,
1111
help="IR to compile with",
12-
choices=["torch_compile", "fx_ts_compat"],
12+
choices=["dynamo_compile", "fx_ts_compat"],
1313
)
1414

1515

py/torch_tensorrt/dynamo/torch_compile/lowering/__init__.py

-7
This file was deleted.

0 commit comments

Comments
 (0)