Skip to content

Commit ea64529

Browse files
committed
chore/fix: Restructure Dynamo directory
- Add `common` directory which stores code common to both the compile and export path, to reduce code duplication and better organize the repository - Update necessary imports, fix minor argument pass-through issues in `fx_ts_compat`
1 parent 2003b07 commit ea64529

20 files changed

+47
-24
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from torch_tensorrt import EngineCapability, Device
99
from torch_tensorrt.fx.utils import LowerPrecision
1010

11-
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
11+
from torch_tensorrt.dynamo.common import CompilationSettings
1212
from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
1313
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
14-
from torch_tensorrt.dynamo.backend._defaults import (
14+
from torch_tensorrt.dynamo.common._defaults import (
1515
PRECISION,
1616
DEBUG,
1717
WORKSPACE_SIZE,

py/torch_tensorrt/dynamo/backend/backends.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import replace, fields
66
import torch._dynamo as td
77

8-
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
8+
from torch_tensorrt.dynamo.common import CompilationSettings
99
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
1010
get_decompositions,
1111
)

py/torch_tensorrt/dynamo/backend/conversion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import io
44
from torch_tensorrt.fx.trt_module import TRTModule
55
from torch_tensorrt import TRTModuleNext
6-
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
7-
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
6+
from torch_tensorrt.dynamo.common import (
7+
CompilationSettings,
88
InputTensorSpec,
99
TRTInterpreter,
1010
)

py/torch_tensorrt/dynamo/backend/lowering/_partition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
6+
from torch_tensorrt.dynamo.common._defaults import MIN_BLOCK_SIZE
77
from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
88
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
99
from torch.fx.graph_module import GraphModule

py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from copy import deepcopy
55
from torch_tensorrt.dynamo import compile
66
from utils import lower_graph_testing
7-
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
7+
from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT
88

99

1010
class TestTRTModuleNextCompilation(TestCase):

py/torch_tensorrt/dynamo/backend/test/test_decompositions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch.testing._internal.common_utils import run_tests, TestCase
44
import torch
55
from torch_tensorrt.dynamo import compile
6-
from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
6+
from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT
77

88

99
class TestLowering(TestCase):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from ._settings import CompilationSettings
2+
3+
from .fx2trt import TRTInterpreter, TRTInterpreterResult
4+
from .input_tensor_spec import InputTensorSpec

py/torch_tensorrt/dynamo/backend/_settings.py renamed to py/torch_tensorrt/dynamo/common/_settings.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Optional, Sequence
33

44
from torch_tensorrt.fx.utils import LowerPrecision
5-
from torch_tensorrt.dynamo.backend._defaults import (
5+
from ._defaults import (
66
PRECISION,
77
DEBUG,
88
WORKSPACE_SIZE,

py/torch_tensorrt/dynamo/common_utils/__init__.py

Whitespace-only changes.

py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
NO_IMPLICIT_BATCH_DIM_SUPPORT,
77
tensorrt_converter,
88
)
9-
from .fx2trt import TRTInterpreter, TRTInterpreterResult # noqa
10-
from .input_tensor_spec import InputTensorSpec # noqa
119
from .lower_setting import LowerSetting # noqa
1210
from .lower import compile # usort: skip #noqa
1311

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py

+29-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
1111
from torch.fx.passes.splitter_base import SplitResult
1212

13-
from .fx2trt import TRTInterpreter, TRTInterpreterResult
13+
from torch_tensorrt.dynamo.common import TRTInterpreter, TRTInterpreterResult
1414
from .lower_setting import LowerSetting
1515
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
1616
from .passes.pass_utils import PassFunc, validate_inference
@@ -21,6 +21,17 @@
2121
from torch_tensorrt.fx.trt_module import TRTModule
2222
from torch_tensorrt.fx.utils import LowerPrecision
2323
from torch_tensorrt._Device import Device
24+
from torch_tensorrt.dynamo.common._defaults import (
25+
PRECISION,
26+
DEBUG,
27+
WORKSPACE_SIZE,
28+
MIN_BLOCK_SIZE,
29+
PASS_THROUGH_BUILD_FAILURES,
30+
MAX_AUX_STREAMS,
31+
VERSION_COMPATIBLE,
32+
OPTIMIZATION_LEVEL,
33+
USE_EXPERIMENTAL_RT,
34+
)
2435

2536
logger = logging.getLogger(__name__)
2637

@@ -34,21 +45,25 @@ def compile(
3445
disable_tf32=False,
3546
sparse_weights=False,
3647
enabled_precisions=set(),
37-
min_block_size: int = 3,
38-
workspace_size=0,
48+
min_block_size: int = MIN_BLOCK_SIZE,
49+
workspace_size=WORKSPACE_SIZE,
3950
dla_sram_size=1048576,
4051
dla_local_dram_size=1073741824,
4152
dla_global_dram_size=536870912,
4253
calibrator=None,
4354
truncate_long_and_double=False,
4455
require_full_compilation=False,
45-
debug=False,
56+
explicit_batch_dimension=False,
57+
debug=DEBUG,
4658
refit=False,
4759
timing_cache_prefix="",
4860
save_timing_cache=False,
4961
cuda_graph_batch_size=-1,
5062
is_aten=False,
51-
use_experimental_fx_rt=False,
63+
use_experimental_rt=USE_EXPERIMENTAL_RT,
64+
max_aux_streams=MAX_AUX_STREAMS,
65+
version_compatible=VERSION_COMPATIBLE,
66+
optimization_level=OPTIMIZATION_LEVEL,
5267
num_avg_timing_iters=1,
5368
torch_executed_ops=[],
5469
torch_executed_modules=[],
@@ -67,11 +82,14 @@ def compile(
6782
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
6883
save_timing_cache: Update timing cache with current timing cache data if set to True.
6984
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
70-
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
85+
use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
86+
max_aux_streams: max number of aux stream to use
87+
version_compatible: enable version compatible feature
88+
optimization_level: builder optimization level
7189
Returns:
7290
A torch.nn.Module lowered by TensorRT.
7391
"""
74-
if use_experimental_fx_rt and not explicit_batch_dimension:
92+
if use_experimental_rt and not explicit_batch_dimension:
7593
raise ValueError(
7694
"The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
7795
)
@@ -122,7 +140,10 @@ def compile(
122140
save_timing_cache=save_timing_cache,
123141
cuda_graph_batch_size=cuda_graph_batch_size,
124142
is_aten=is_aten,
125-
use_experimental_rt=use_experimental_fx_rt,
143+
use_experimental_rt=use_experimental_rt,
144+
max_aux_streams=max_aux_streams,
145+
version_compatible=version_compatible,
146+
optimization_level=optimization_level,
126147
)
127148
lowerer = Lowerer.create(lower_setting=lower_setting)
128149
return lowerer(module, inputs)

py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import nn
55
from torch.fx.passes.pass_manager import PassManager
66

7-
from .input_tensor_spec import InputTensorSpec
7+
from torch_tensorrt.dynamo.common import InputTensorSpec
88
from torch_tensorrt.fx.passes.lower_basic_pass import (
99
fuse_permute_linear,
1010
fuse_permute_matmul,

py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
1111
from torch_tensorrt.fx.utils import LowerPrecision
1212
from torch_tensorrt import _Input
13-
from ..input_tensor_spec import InputTensorSpec
13+
from torch_tensorrt.dynamo.common import InputTensorSpec
1414

1515
from ..lower_setting import LowerSetting
1616
from torch_tensorrt.fx.observer import Observer

py/torch_tensorrt/dynamo/fx_ts_compat/test/core/test_input_tensor_spec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
import torch_tensorrt
77
from torch.testing._internal.common_utils import run_tests, TestCase
8-
from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting
8+
from torch_tensorrt.dynamo.common import InputTensorSpec
99

1010

1111
class TestTRTModule(TestCase):

py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.fx.passes import shape_prop
1414
from torch.fx.passes.infra.pass_base import PassResult
1515
from torch.testing._internal.common_utils import TestCase
16-
from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter
16+
from torch_tensorrt.dynamo.common import InputTensorSpec, TRTInterpreter
1717
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
1818
compose_bmm,
1919
compose_chunk,

py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from transformers import BertModel
99

10-
from torch_tensorrt.dynamo.common_utils.test_utils import (
10+
from torch_tensorrt.dynamo.common.test_utils import (
1111
COSINE_THRESHOLD,
1212
cosine_similarity,
1313
)

0 commit comments

Comments
 (0)