import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
from torch.fx.passes.splitter_base import SplitResult

- from .fx2trt import TRTInterpreter, TRTInterpreterResult
+ from torch_tensorrt.dynamo.common import TRTInterpreter, TRTInterpreterResult
from .lower_setting import LowerSetting
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
from .passes.pass_utils import PassFunc, validate_inference
...
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.fx.utils import LowerPrecision
from torch_tensorrt._Device import Device
+ from torch_tensorrt.dynamo.common._defaults import (
+     PRECISION,
+     DEBUG,
+     WORKSPACE_SIZE,
+     MIN_BLOCK_SIZE,
+     PASS_THROUGH_BUILD_FAILURES,
+     MAX_AUX_STREAMS,
+     VERSION_COMPATIBLE,
+     OPTIMIZATION_LEVEL,
+     USE_EXPERIMENTAL_RT,
+ )

logger = logging.getLogger(__name__)
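The hunk above centralizes the compile-time defaults in a shared `torch_tensorrt.dynamo.common._defaults` module instead of hard-coding them at each call site. As a rough sketch, that module plausibly looks like the following; only the constant names are confirmed by this diff, while every value shown is an assumption inferred from the hard-coded defaults being replaced elsewhere in this PR:

```python
# Hypothetical sketch of torch_tensorrt/dynamo/common/_defaults.py.
# The names come from the import above; the values are assumptions.
from torch_tensorrt.fx.utils import LowerPrecision

PRECISION = LowerPrecision.FP32      # assumed default precision
DEBUG = False                        # was the hard-coded `debug=False`
WORKSPACE_SIZE = 0                   # was the hard-coded `workspace_size=0`
MIN_BLOCK_SIZE = 3                   # was the hard-coded `min_block_size: int = 3`
PASS_THROUGH_BUILD_FAILURES = False  # assumed: fail loudly by default
MAX_AUX_STREAMS = None               # assumed: let TensorRT decide
VERSION_COMPATIBLE = False           # assumed: feature off by default
OPTIMIZATION_LEVEL = None            # assumed: keep the builder's default
USE_EXPERIMENTAL_RT = False          # was the hard-coded `use_experimental_fx_rt=False`
```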
@@ -34,21 +45,25 @@ def compile(
    disable_tf32=False,
    sparse_weights=False,
    enabled_precisions=set(),
-   min_block_size: int = 3,
-   workspace_size=0,
+   min_block_size: int = MIN_BLOCK_SIZE,
+   workspace_size=WORKSPACE_SIZE,
    dla_sram_size=1048576,
    dla_local_dram_size=1073741824,
    dla_global_dram_size=536870912,
    calibrator=None,
    truncate_long_and_double=False,
    require_full_compilation=False,
-   debug=False,
+   explicit_batch_dimension=False,
+   debug=DEBUG,
    refit=False,
    timing_cache_prefix="",
    save_timing_cache=False,
    cuda_graph_batch_size=-1,
    is_aten=False,
-   use_experimental_fx_rt=False,
+   use_experimental_rt=USE_EXPERIMENTAL_RT,
+   max_aux_streams=MAX_AUX_STREAMS,
+   version_compatible=VERSION_COMPATIBLE,
+   optimization_level=OPTIMIZATION_LEVEL,
    num_avg_timing_iters=1,
    torch_executed_ops=[],
    torch_executed_modules=[],
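With the renamed and newly surfaced knobs above, callers can tune the TensorRT builder directly from `compile()`. A hedged usage sketch follows; the model, input shape, and the `torch_tensorrt.fx.lower` import path are placeholders rather than anything this PR specifies:

```python
# Hedged usage sketch for the updated signature; MyModel is a stand-in.
import torch
import torch.nn as nn
from torch_tensorrt.fx import lower  # assumed import path for this file


class MyModel(nn.Module):
    def forward(self, x):
        return torch.relu(x)


model = MyModel().cuda().eval()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

trt_model = lower.compile(
    model,
    inputs,
    explicit_batch_dimension=True,  # required when use_experimental_rt=True
    use_experimental_rt=True,       # next-generation TRTModule runtime
    max_aux_streams=2,              # cap auxiliary TensorRT streams
    version_compatible=True,        # build a version-compatible engine
    optimization_level=3,           # TensorRT builder optimization level
)
```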
@@ -67,11 +82,14 @@ def compile(
        timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
        save_timing_cache: Update timing cache with current timing cache data if set to True.
        cuda_graph_batch_size: Cuda graph batch size, default to be -1.
-       use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+       use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+       max_aux_streams: Maximum number of auxiliary TensorRT streams to use.
+       version_compatible: Enable TensorRT's version-compatible engine feature.
+       optimization_level: TensorRT builder optimization level.
    Returns:
        A torch.nn.Module lowered by TensorRT.
    """
-   if use_experimental_fx_rt and not explicit_batch_dimension:
+   if use_experimental_rt and not explicit_batch_dimension:
        raise ValueError(
-           "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
+           "The experimental unified runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_rt=True"
        )
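The renamed guard keeps the old behavior: the experimental runtime requires explicit batch. Reusing the placeholders from the previous sketch, the failure mode looks roughly like this:

```python
# explicit_batch_dimension defaults to False, so this raises ValueError.
try:
    lower.compile(model, inputs, use_experimental_rt=True)
except ValueError as err:
    print(err)  # "The experimental unified runtime only supports explicit batch. ..."
```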
@@ -122,7 +140,10 @@ def compile(
        save_timing_cache=save_timing_cache,
        cuda_graph_batch_size=cuda_graph_batch_size,
        is_aten=is_aten,
-       use_experimental_rt=use_experimental_fx_rt,
+       use_experimental_rt=use_experimental_rt,
+       max_aux_streams=max_aux_streams,
+       version_compatible=version_compatible,
+       optimization_level=optimization_level,
    )
    lowerer = Lowerer.create(lower_setting=lower_setting)
    return lowerer(module, inputs)
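For the forwarding above to type-check, `LowerSetting` presumably gains matching fields in `lower_setting.py`. A minimal sketch, assuming the field names mirror the keyword arguments and the defaults track `_defaults`:

```python
# Hedged sketch of the new LowerSetting fields (not the full class).
from dataclasses import dataclass
from typing import Optional


@dataclass
class LowerSetting:  # abbreviated; the real class carries many more fields
    use_experimental_rt: bool = False
    max_aux_streams: Optional[int] = None
    version_compatible: bool = False
    optimization_level: Optional[int] = None
```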