
Commit c0f0a22

feat: Automatically detect C++ dependency presence
- Default to automatically detecting the presence of the C++ dependency, and use the appropriate runtime when the user has not specified one
- Improve test timing for TRT Backend testing, which now utilizes the C++ runtime
1 parent f2f40a2 commit c0f0a22
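
The detection itself reduces to a guarded import of the C++-backed module. A minimal sketch of the pattern this commit introduces (the full implementation, use_python_runtime_parser, appears in the new common_utils file below):

    # Sketch: prefer the C++ runtime when its binding imports cleanly,
    # otherwise fall back to the pure-Python runtime.
    try:
        from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

        use_python_runtime = False  # C++ dependency present
    except ImportError:
        use_python_runtime = True  # C++ dependency not installed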

File tree: 12 files changed (+107, -54 lines)

py/torch_tensorrt/dynamo/_TorchTensorRTModule.py (-3)

@@ -69,9 +69,6 @@ def __init__(
         )

         """
-        logger.warning(
-            "TorchTensorRTModule should be considered experimental stability, APIs are subject to change. Note: TorchTensorRTModule only supports engines built with explict batch"
-        )
         super(TorchTensorRTModule, self).__init__()

         if not isinstance(serialized_engine, bytearray):

py/torch_tensorrt/dynamo/backend/__init__.py (+9, -11)

@@ -19,7 +19,7 @@
     MAX_AUX_STREAMS,
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
-    USE_EXPERIMENTAL_RT,
+    USE_PYTHON_RUNTIME,
 )


@@ -52,7 +52,7 @@ def compile(
     max_aux_streams=MAX_AUX_STREAMS,
     version_compatible=VERSION_COMPATIBLE,
     optimization_level=OPTIMIZATION_LEVEL,
-    use_experimental_rt=USE_EXPERIMENTAL_RT,
+    use_python_runtime=USE_PYTHON_RUNTIME,
     **kwargs,
 ):
     if debug:
@@ -65,11 +65,6 @@ def compile(
         + "torch_executed_ops, pass_through_build_failures}"
     )

-    if "use_experimental_fx_rt" in kwargs:
-        use_experimental_rt = kwargs["use_experimental_fx_rt"]
-
-    logger.info(f"Using {'C++' if use_experimental_rt else 'Python'} TRT Runtime")
-
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]

@@ -107,7 +102,7 @@ def compile(
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
-        use_experimental_rt=use_experimental_rt,
+        use_python_runtime=use_python_runtime,
         **kwargs,
     )

@@ -134,7 +129,7 @@ def create_backend(
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -150,7 +145,9 @@ def create_backend(
         version_compatible: Provide version forward-compatibility for engine plan files
         optimization_level: Builder optimization 0-5, higher levels imply longer build time,
             searching for more optimization options. TRT defaults to 3
-        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
     Returns:
         Backend for torch.compile
     """
@@ -165,5 +162,6 @@
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
-        use_experimental_rt=use_experimental_rt,
+        use_python_runtime=use_python_runtime,
+        **kwargs,
     )
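
With this change, use_python_runtime is tri-state. A hedged usage sketch, assuming placeholder model and inputs and the compile entrypoint defined in this file:

    from torch_tensorrt.dynamo.backend import compile

    trt_model = compile(model, inputs, use_python_runtime=None)   # new default: auto-detect, preferring C++ if available
    trt_model = compile(model, inputs, use_python_runtime=True)   # force the Python TRTModule runtime
    trt_model = compile(model, inputs, use_python_runtime=False)  # force the C++ TorchTensorRTModule runtime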

py/torch_tensorrt/dynamo/backend/_defaults.py (+1, -1)

@@ -9,4 +9,4 @@
 MAX_AUX_STREAMS = None
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
-USE_EXPERIMENTAL_RT = False
+USE_PYTHON_RUNTIME = None

py/torch_tensorrt/dynamo/backend/_settings.py (+3, -3)

@@ -11,11 +11,11 @@
     MAX_AUX_STREAMS,
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
-    USE_EXPERIMENTAL_RT,
+    USE_PYTHON_RUNTIME,
 )


-@dataclass(frozen=True)
+@dataclass
 class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
@@ -26,4 +26,4 @@ class CompilationSettings:
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
-    use_experimental_rt: bool = USE_EXPERIMENTAL_RT
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
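
Dropping frozen=True is what lets the parsed runtime choice be written back onto the settings object after construction, as parse_dynamo_kwargs does in utils.py below. A minimal illustration:

    from torch_tensorrt.dynamo.backend._settings import CompilationSettings

    settings = CompilationSettings()     # use_python_runtime defaults to None
    settings.use_python_runtime = False  # allowed now; a frozen dataclass would raise FrozenInstanceError here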

py/torch_tensorrt/dynamo/backend/conversion.py (+8, -7)

@@ -55,7 +55,14 @@ def convert_module(
         optimization_level=settings.optimization_level,
     )

-    if settings.use_experimental_rt:
+    if settings.use_python_runtime:
+        return TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
+
+    else:
         from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

         with io.BytesIO() as engine_bytes:
@@ -67,9 +74,3 @@ def convert_module(
             input_binding_names=interpreter_result.input_names,
             output_binding_names=interpreter_result.output_names,
         )
-    else:
-        return TRTModule(
-            engine=interpreter_result.engine,
-            input_names=interpreter_result.input_names,
-            output_names=interpreter_result.output_names,
-        )

py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py (+3, -3)

@@ -40,7 +40,7 @@ def forward(self, x, y):
             min_block_size=1,
             pass_through_build_failures=True,
             torch_executed_ops={"torch.ops.aten.add.Tensor"},
-            use_experimental_rt=True,
+            use_python_runtime=False,
             debug=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -108,7 +108,7 @@ def forward(self, x, y):
             min_block_size=1,
             pass_through_build_failures=True,
             torch_executed_ops={"torch.ops.aten.add.Tensor"},
-            use_experimental_rt=True,
+            use_python_runtime=False,
             debug=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -149,7 +149,7 @@ def forward(self, x, y):
             inputs,
             min_block_size=1,
             pass_through_build_failures=True,
-            use_experimental_rt=True,
+            use_python_runtime=False,
             optimization_level=4,
             version_compatible=True,
             max_aux_streams=5,

py/torch_tensorrt/dynamo/backend/utils.py (+4)

@@ -5,6 +5,7 @@
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device
+from ..common_utils import use_python_runtime_parser


 logger = logging.getLogger(__name__)
@@ -102,6 +103,9 @@ def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
     if settings.debug:
         logger.setLevel(logging.DEBUG)

+    # Parse input runtime specification
+    settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
+
     logger.debug(f"Compiling with Settings:\n{settings}")

     return settings
py/torch_tensorrt/dynamo/common_utils.py (new file, +36)

@@ -0,0 +1,36 @@
+import logging
+from typing import Optional
+
+
+logger = logging.getLogger(__name__)
+
+
+def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
+    """Parses a user-provided input argument regarding Python runtime
+
+    Automatically handles cases where the user has not specified a runtime (None)
+
+    Returns True if the Python runtime should be used, False if the C++ runtime should be used
+    """
+    using_python_runtime = use_python_runtime
+    reason = ""
+
+    # Runtime was manually specified by the user
+    if using_python_runtime is not None:
+        reason = "as requested by user"
+    # Runtime was not manually specified by the user, automatically detect runtime
+    else:
+        try:
+            from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+
+            using_python_runtime = False
+            reason = "since C++ dependency was detected as present"
+        except ImportError:
+            using_python_runtime = True
+            reason = "since import failed, C++ dependency not installed"
+
+    logger.info(
+        f"Using {'Python' if using_python_runtime else 'C++'} {reason} TRT Runtime"
+    )
+
+    return using_python_runtime
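
A short usage sketch of the new parser; the module path torch_tensorrt.dynamo.common_utils is inferred from the ..common_utils relative imports elsewhere in this commit:

    from torch_tensorrt.dynamo.common_utils import use_python_runtime_parser

    use_python_runtime_parser(None)   # auto-detect: False if the C++ dependency imports, else True
    use_python_runtime_parser(True)   # explicit choice honored: returns True (Python runtime)
    use_python_runtime_parser(False)  # explicit choice honored: returns False (C++ runtime)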

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py (+20, -16)

@@ -14,6 +14,7 @@
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
 from .passes.pass_utils import PassFunc, validate_inference
+from ..common_utils import use_python_runtime_parser
 from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

@@ -48,7 +49,7 @@ def compile(
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
     is_aten=False,
-    use_experimental_fx_rt=False,
+    use_python_runtime=None,
     max_aux_streams=None,
     version_compatible=False,
     optimization_level=None,
@@ -70,7 +71,9 @@ def compile(
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
         cuda_graph_batch_size: Cuda graph batch size, default to be -1.
-        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
         max_aux_streams: max number of aux stream to use
         version_compatible: enable version compatible feature
         optimization_level: builder optimization level
@@ -111,6 +114,9 @@ def compile(
             "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
         )

+    # Parse user-specification of which runtime to use
+    use_python_runtime = use_python_runtime_parser(use_python_runtime)
+
     lower_setting = LowerSetting(
         device=device,
         min_block_size=min_block_size,
@@ -123,7 +129,7 @@ def compile(
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
         is_aten=is_aten,
-        use_experimental_rt=use_experimental_fx_rt,
+        use_python_runtime=use_python_runtime,
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
@@ -202,7 +208,7 @@ def default_split_function(
     splitter_setting = TRTSplitterSetting()
     splitter_setting.use_implicit_batch_dim = False
     splitter_setting.min_block_size = lower_setting.min_block_size
-    splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
+    splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime
     splitter = TRTSplitter(model, inputs, settings=splitter_setting)
     splitter.node_support_preview()
     return splitter.generate_split_results()
@@ -224,9 +230,17 @@ def lower_pass(
         """
         interpreter = create_trt_interpreter(lower_setting)
         interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
-        if lower_setting.use_experimental_rt:
-            import io
+        if lower_setting.use_python_runtime:
+            trt_module = TRTModule(
+                engine=interp_res.engine,
+                input_names=interp_res.input_names,
+                output_names=interp_res.output_names,
+                cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
+            )
+            return trt_module

+        else:
+            import io
             from torch_tensorrt._Device import Device
             from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

@@ -240,16 +254,6 @@ def lower_pass(
                 input_binding_names=interp_res.input_names,
                 output_binding_names=interp_res.output_names,
                 target_device=Device(f"cuda:{torch.cuda.current_device()}"),
-                # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
-            )
-            return trt_module
-
-        else:
-            trt_module = TRTModule(
-                engine=interp_res.engine,
-                input_names=interp_res.input_names,
-                output_names=interp_res.output_names,
-                cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
             )
             return trt_module
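
Note the bridging line in default_split_function above: the splitter still consumes the legacy boolean, so the resolved flag is simply inverted when passed down. Because use_python_runtime_parser has already reduced None to a concrete bool by this point:

    # use_python_runtime=True  -> use_experimental_rt=False (Python TRTModule path)
    # use_python_runtime=False -> use_experimental_rt=True  (C++ TorchTensorRTModule path)
    splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime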

py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py (+3, -2)

@@ -68,7 +68,8 @@ class LowerSetting(LowerSettingBasic):
         meaning all possible tactic sources.
     correctness_atol: absolute tolerance for correctness check
     correctness_rtol: relative tolerance for correctness check
-    use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+    use_python_runtime: Whether to use Python runtime or C++ runtime. None implies the user has not
+        selected a runtime, and the frontend will automatically do so on their behalf
     max_aux_streams: max number of aux stream to use
     version_compatible: enable version compatible feature
     optimization_level: builder optimization level
@@ -95,7 +96,7 @@ class LowerSetting(LowerSettingBasic):
     tactic_sources: Optional[int] = None
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
-    use_experimental_rt: bool = False
+    use_python_runtime: Optional[bool] = None
     max_aux_streams: Optional[int] = None
     version_compatible: bool = False
     optimization_level: Optional[int] = None

py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py (+10, -8)

@@ -15,13 +15,20 @@
 def lower_mod_default(
     mod: torch.fx.GraphModule,
     inputs: Tensors,
-    use_experimental_rt: bool = False,
+    use_python_runtime: bool = False,
 ) -> TRTModule:
     interp = TRTInterpreter(
         mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
     )
     interpreter_result = interp.run()
-    if use_experimental_rt:
+    if use_python_runtime:
+        res_mod = TRTModule(
+            interpreter_result.engine,
+            interpreter_result.input_names,
+            interpreter_result.output_names,
+        )
+
+    else:
         import io

         from torch_tensorrt._Device import Device
@@ -39,12 +46,7 @@ def lower_mod_default(
             target_device=Device(f"cuda:{torch.cuda.current_device()}"),
             # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
         )
-    else:
-        res_mod = TRTModule(
-            interpreter_result.engine,
-            interpreter_result.input_names,
-            interpreter_result.output_names,
-        )
+
     return res_mod