@@ -14,6 +14,7 @@
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
 from .passes.pass_utils import PassFunc, validate_inference
+from ..common_utils import use_python_runtime_parser
 from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
 
@@ -48,7 +49,7 @@ def compile(
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
     is_aten=False,
-    use_experimental_fx_rt=False,
+    use_python_runtime=None,
     max_aux_streams=None,
     version_compatible=False,
     optimization_level=None,
@@ -70,7 +71,9 @@ def compile(
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
         cuda_graph_batch_size: Cuda graph batch size, default to be -1.
-        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        use_python_runtime: Whether to strictly use the Python runtime or the C++ runtime. To auto-select a
+            runtime based on C++ dependency presence (preferentially choosing the C++ runtime if available),
+            leave the argument as None.
         max_aux_streams: max number of aux stream to use
         version_compatible: enable version compatible feature
         optimization_level: builder optimization level
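
The tri-state semantics documented above can be exercised as in the following minimal sketch. The module import path, the example model, and the positional argument names are assumptions for illustration, not part of this diff:

```python
import torch
import torch_tensorrt.fx.lower as lower  # assumed import path for this module

model = torch.nn.Linear(8, 4).cuda().eval()
inputs = [torch.randn(2, 8).cuda()]

# Force the pure-Python TRTModule wrapper.
trt_py = lower.compile(model, inputs, use_python_runtime=True)

# Force the C++ (TorchTensorRTModule) runtime.
trt_cpp = lower.compile(model, inputs, use_python_runtime=False)

# Default: auto-select, preferring the C++ runtime when its bindings are available.
trt_auto = lower.compile(model, inputs, use_python_runtime=None)
```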
@@ -111,6 +114,9 @@ def compile(
             "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
         )
 
+    # Parse user-specification of which runtime to use
+    use_python_runtime = use_python_runtime_parser(use_python_runtime)
+
     lower_setting = LowerSetting(
         device=device,
         min_block_size=min_block_size,
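
The body of `use_python_runtime_parser` is not part of this diff. A minimal sketch of the behavior the docstring describes (honor an explicit choice, otherwise prefer the C++ runtime when its bindings import cleanly) might look like the following; the import-probe used for detection is an assumption:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


def use_python_runtime_parser(use_python_runtime: Optional[bool]) -> bool:
    """Resolve the tri-state runtime flag into a concrete choice."""
    if use_python_runtime is not None:
        # The user made an explicit choice; honor it.
        return use_python_runtime
    # Auto-detect: prefer the C++ runtime if its bindings are importable.
    try:
        from torch_tensorrt.dynamo._TorchTensorRTModule import (  # noqa: F401
            TorchTensorRTModule,
        )
        return False  # C++ runtime available
    except ImportError:
        logger.info("C++ runtime unavailable, falling back to the Python runtime")
        return True
```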
@@ -123,7 +129,7 @@ def compile(
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
         is_aten=is_aten,
-        use_experimental_rt=use_experimental_fx_rt,
+        use_python_runtime=use_python_runtime,
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
@@ -202,7 +208,7 @@ def default_split_function(
     splitter_setting = TRTSplitterSetting()
     splitter_setting.use_implicit_batch_dim = False
     splitter_setting.min_block_size = lower_setting.min_block_size
-    splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
+    splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime
     splitter = TRTSplitter(model, inputs, settings=splitter_setting)
     splitter.node_support_preview()
     return splitter.generate_split_results()
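
Note the boolean inversion in the hunk above: `TRTSplitterSetting` still phrases its flag in terms of the next-generation ("experimental") runtime, so the new Python-runtime flag must be negated at that boundary. A self-contained illustration of the mapping, assuming `None` has already been resolved by `use_python_runtime_parser`:

```python
# use_python_runtime=True  -> use_experimental_rt=False (legacy Python TRTModule path)
# use_python_runtime=False -> use_experimental_rt=True  (next-gen, C++-capable runtime)
for use_python_runtime in (True, False):
    use_experimental_rt = not use_python_runtime
    print(
        f"use_python_runtime={use_python_runtime} -> "
        f"splitter_setting.use_experimental_rt={use_experimental_rt}"
    )
```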
@@ -224,9 +230,17 @@ def lower_pass(
     """
     interpreter = create_trt_interpreter(lower_setting)
     interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
-    if lower_setting.use_experimental_rt:
-        import io
+    if lower_setting.use_python_runtime:
+        trt_module = TRTModule(
+            engine=interp_res.engine,
+            input_names=interp_res.input_names,
+            output_names=interp_res.output_names,
+            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
+        )
+        return trt_module
 
+    else:
+        import io
         from torch_tensorrt._Device import Device
         from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
 
@@ -240,16 +254,6 @@ def lower_pass(
             input_binding_names=interp_res.input_names,
             output_binding_names=interp_res.output_names,
             target_device=Device(f"cuda:{torch.cuda.current_device()}"),
-            # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
-        )
-        return trt_module
-
-    else:
-        trt_module = TRTModule(
-            engine=interp_res.engine,
-            input_names=interp_res.input_names,
-            output_names=interp_res.output_names,
-            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
         )
         return trt_module
 