
Commit 4ba253f

fix: Upgrade Torch version, enable options
- Upgrade Torch version across the stack
- Update Dynamo sample with advanced usage to indicate usage of the new `options` argument in `torch.compile`
- Enable the options argument in `torch.compile`, including improved input handling in the default torch_tensorrt backend
- ResNet example now features `torch_tensorrt.dynamo.compile`, while the transformers example features `torch_tensorrt.compile(..., ir="dynamo_compile", ...)`
- Fix bugs in the core runtime and `TRTInterpreter` to address issues arising with the latest PyTorch distribution
- Add feature in `TRTInterpreter` to specify output data types
- Add `pass_through_build_failures` argument to the `torch_tensorrt.dynamo.torch_compile` frontend
1 parent 6d2a26b commit 4ba253f
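For reference, a minimal sketch of the two example compile pathways named above; the module and inputs are hypothetical stand-ins, and only the entry points and the `ir="dynamo_compile"` string come from this commit:

import torch
import torch_tensorrt

model = MyModule().eval().cuda()  # hypothetical module, for illustration only
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

# ResNet example pathway: the Dynamo frontend invoked directly
trt_model = torch_tensorrt.dynamo.compile(model, inputs=inputs)

# Transformers example pathway: the unified frontend selecting the Dynamo IR
trt_model = torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs)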

File tree

14 files changed: +93 −39 lines

.bazelrc

+1 −1

@@ -22,7 +22,7 @@
 # +------------------------------------------------------------+
 # Enable colorful output of GCC
 build --cxxopt="-fdiagnostics-color=always"
-build --cxxopt='-std=c++14'
+build --cxxopt='-std=c++17'
 #build --linkopt="-Wl,--no-as-needed"

.circleci/config.yml

+4 −4

@@ -269,10 +269,10 @@ commands:
     parameters:
       torch-build:
         type: string
-        default: "2.1.0.dev20230419+cu118"
+        default: "2.1.0.dev20230601+cu118"
       torchvision-build:
         type: string
-        default: "0.16.0.dev20230419+cu118"
+        default: "0.16.0.dev20230601+cu118"
       torch-build-index:
         type: string
         default: "https://download.pytorch.org/whl/nightly/cu118"
@@ -1352,10 +1352,10 @@ parameters:
  # Nightly platform config
  torch-build:
    type: string
-    default: "2.1.0.dev20230419+cu118"
+    default: "2.1.0.dev20230601+cu118"
  torchvision-build:
    type: string
-    default: "0.16.0.dev20230419+cu118"
+    default: "0.16.0.dev20230601+cu118"
  torch-build-index:
    type: string
    default: "https://download.pytorch.org/whl/nightly/cu118"

CMakeLists.txt

+2 −2

@@ -2,8 +2,8 @@
 cmake_minimum_required(VERSION 3.17)
 project(Torch-TensorRT LANGUAGES CXX)

-# use c++14 like PyTorch
-set(CMAKE_CXX_STANDARD 14)
+# use c++17 like PyTorch
+set(CMAKE_CXX_STANDARD 17)

 # Build the libraries with -fPIC
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)

README.md

+1 −1

@@ -116,7 +116,7 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
 These are the following dependencies used to verify the testcases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.

 - Bazel 5.2.0
-- Libtorch 2.1.0.dev20230419 (built with CUDA 11.8)
+- Libtorch 2.1.0.dev20230601 (built with CUDA 11.8)
 - CUDA 11.8
 - cuDNN 8.8.0
 - TensorRT 8.6.1

WORKSPACE

+4 −4

@@ -51,17 +51,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "c8407ae3462c344ae3814e82023e22ece759ebe75023f35bdf62e9c0a7e79035",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "76f983bd6d784cc0a95c679034d297abe36911c16b2188498b13a9028177e28e",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
 )

 # Download these tarballs manually from the NVIDIA website

core/runtime/TRTEngine.cpp

+2 −1

@@ -111,7 +111,8 @@ TRTEngine::TRTEngine(
   for (size_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
     auto binding_name = _in_binding_names[pyt_idx];
     auto trt_idx = cuda_engine->getBindingIndex(binding_name.c_str());
-    std::string engine_binded_name = cuda_engine->getIOTensorName(pyt_idx);
+    std::string engine_binded_name = cuda_engine->getIOTensorName(trt_idx);
+
     TORCHTRT_CHECK(
         (binding_name == engine_binded_name),
         "Could not find a TensorRT engine binding for input named " << binding_name);

(Fix rationale: getIOTensorName expects TensorRT's own binding index, so looking it up with the PyTorch-side pyt_idx could return the wrong tensor name whenever the two orderings differ, tripping the TORCHTRT_CHECK that follows.)

examples/dynamo/dynamo_compile_advanced_usage.py

+23 −15

@@ -9,7 +9,6 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 import torch
-from torch_tensorrt.dynamo.backend import create_backend
 from torch_tensorrt.fx.lower_setting import LowerPrecision

 # %%
@@ -39,15 +38,19 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):

 # Next, we compile the model using torch.compile
 # For the default settings, we can simply call torch.compile
-# with the backend "tensorrt", and run the model on an
+# with the backend "torch_tensorrt", and run the model on an
 # input to cause compilation, as so:
-optimized_model = torch.compile(model, backend="tensorrt")
+optimized_model = torch.compile(model, backend="torch_tensorrt")
 optimized_model(*sample_inputs)

 # %%
 # Compilation with `torch.compile` Using Custom Settings
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

+# First, we use Torch utilities to clean up the workspace
+# after the previous compile invocation
+torch._dynamo.reset()
+
 # Define sample half inputs and initialize model
 sample_inputs_half = [
     torch.rand((5, 7)).half().cuda(),
@@ -58,20 +61,25 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):

 # %%

 # If we want to customize certain options in the backend,
-# but still use the torch.compile call directly, we can call the
-# convenience/helper function create_backend to create a custom backend
-# which has been pre-populated with certain keys
-custom_backend = create_backend(
-    lower_precision=LowerPrecision.FP16,
-    debug=True,
-    min_block_size=2,
-    torch_executed_ops={},
-    optimization_level=4,
-    use_experimental_rt=True,
-)
+# but still use the torch.compile call directly, we can provide
+# custom options to the backend via the "options" keyword
+# which takes in a dictionary mapping options to values.
+#
+# For accepted backend options, see the CompilationSettings dataclass:
+# py/torch_tensorrt/dynamo/backend/_settings.py
+backend_kwargs = {
+    "lower_precision": LowerPrecision.FP16,
+    "debug": True,
+    "min_block_size": 2,
+    "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
+    "optimization_level": 4,
+    "use_experimental_rt": True,
+}

 # Run the model on an input to cause compilation, as so:
-optimized_model_custom = torch.compile(model_half, backend=custom_backend)
+optimized_model_custom = torch.compile(
+    model_half, backend="torch_tensorrt", options=backend_kwargs
+)
 optimized_model_custom(*sample_inputs_half)

 # %%
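A usage note on the new options argument: as the backend changes further down show, each key is validated against the fields of the CompilationSettings dataclass and unrecognized keys are silently dropped rather than raising an error. A hedged sketch, where model is hypothetical and "not_a_real_option" is deliberately bogus:

optimized = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"min_block_size": 2, "not_a_real_option": 1},  # second key is filtered out
)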

py/requirements.txt

+2 −2

@@ -2,7 +2,7 @@ numpy
 packaging
 pybind11==2.6.2
 --extra-index-url https://download.pytorch.org/whl/nightly/cu118
-torch==2.1.0.dev20230419+cu118
-torchvision==0.16.0.dev20230419+cu118
+torch==2.1.0.dev20230601+cu118
+torchvision==0.16.0.dev20230601+cu118
 --extra-index-url https://pypi.ngc.nvidia.com
 tensorrt==8.6.1

py/torch_tensorrt/dynamo/backend/__init__.py

+2

@@ -49,6 +49,7 @@ def compile(
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
     torch_executed_modules=[],
+    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams=MAX_AUX_STREAMS,
     version_compatible=VERSION_COMPATIBLE,
     optimization_level=OPTIMIZATION_LEVEL,
@@ -94,6 +95,7 @@ def compile(
         workspace_size=workspace_size,
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
+        pass_through_build_failures=pass_through_build_failures,
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
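A sketch of the new argument at the call site, via the compile entry point updated above; its semantics are assumed here (surfacing TRT build or conversion errors rather than silently falling back to PyTorch execution), with model and inputs as hypothetical stand-ins:

import torch_tensorrt

trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=inputs,
    pass_through_build_failures=True,  # assumed: raise on conversion failure
)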

py/torch_tensorrt/dynamo/backend/backends.py

+18

@@ -2,6 +2,7 @@
 from typing import Sequence
 import torch
 from functools import partial
+from dataclasses import replace, fields
 import torch._dynamo as td

 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
@@ -28,7 +29,24 @@ def torch_tensorrt_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
+    **kwargs
 ):
+    # If the user specifies keyword args, overwrite those fields in settings
+    # Validate all specified kwargs to ensure they are true fields of the dataclass
+    #
+    # Note: kwargs provided by torch.compile are wrapped in the "options" key
+    if kwargs:
+        if "options" in kwargs and len(kwargs) == 1:
+            kwargs = kwargs["options"]
+
+        valid_attrs = {attr.name for attr in fields(settings)}
+        valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
+        settings = replace(settings, **valid_kwargs)
+
+    # Enable debug/verbose mode if requested
+    if settings.debug:
+        logger.setLevel(logging.DEBUG)
+
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

     return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
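The kwargs handling above is a general dataclass pattern; a self-contained sketch of the same logic outside the backend (all names here are illustrative, not the library's):

from dataclasses import dataclass, fields, replace

@dataclass(frozen=True)
class Settings:
    debug: bool = False
    min_block_size: int = 5

def apply_options(settings: Settings, **kwargs) -> Settings:
    # torch.compile forwards backend kwargs wrapped in a single "options" key
    if "options" in kwargs and len(kwargs) == 1:
        kwargs = kwargs["options"]
    # Keep only keys that are true fields of the dataclass
    valid = {f.name for f in fields(settings)}
    return replace(settings, **{k: v for k, v in kwargs.items() if k in valid})

settings = apply_options(Settings(), options={"debug": True, "bogus": 1})
assert settings.debug and settings.min_block_size == 5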

py/torch_tensorrt/dynamo/backend/conversion.py

+10

@@ -27,11 +27,21 @@ def convert_module(
     Returns:
         TRTModule or TRTModuleNext
     """
+    # Specify module output data types to ensure TRT output types agree with
+    # that of the equivalent Torch module
+    module_outputs = module(*inputs)
+
+    if not isinstance(module_outputs, (list, tuple)):
+        module_outputs = [module_outputs]
+
+    output_dtypes = list(output.dtype for output in module_outputs)
+
     interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
+        output_dtypes=output_dtypes,
     )

     interpreter_result = interpreter.run(
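The dtype inference above amounts to one extra eager forward pass through the FX module at compile time; a standalone sketch of the pattern:

import torch

def infer_output_dtypes(module: torch.nn.Module, inputs):
    # Run the module once and record each output's dtype, so the TRT engine
    # can be asked to produce outputs matching the Torch types exactly
    outputs = module(*inputs)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    return [out.dtype for out in outputs]

Note the tradeoff: agreement of output types is bought with one additional execution of the (possibly large) module during compilation.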

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

+16 −1

@@ -41,6 +41,7 @@ def __init__(
         explicit_batch_dimension: bool = True,
         explicit_precision: bool = False,
         logger_level=None,
+        output_dtypes=None,
     ):
         super().__init__(module)

@@ -79,6 +80,9 @@ def __init__(
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()

+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes
+
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
@@ -179,13 +183,17 @@ def run(
             algorithm_selector: set up algorithm selection for certain layer
             timing_cache: enable timing cache for TensorRT
             profiling_verbosity: TensorRT logging level
+            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+            version_compatible: Provide version forward-compatibility for engine plan files
+            optimization_level: Builder optimization 0-5, higher levels imply longer build time,
+                searching for more optimization options. TRT defaults to 3
         Return:
             TRTInterpreterResult
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overriden by specifying output_dtypes
         self.output_fp16 = (
             not force_fp32_output and lower_precision == LowerPrecision.FP16
         )
@@ -373,6 +381,11 @@ def output(self, target, args, kwargs):
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")

+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
+            )
+
         for i, output in enumerate(outputs):
             if any(
                 op_name in output.name.split("_")
@@ -397,6 +410,8 @@ def output(self, target, args, kwargs):
             self.network.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
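A simplified restatement of the resulting output-dtype precedence in output(): bool detection first, then user-specified output_dtypes, then the FP16 fallback. The mapping dict is a stand-in for the library's torch-to-TRT dtype helper, not its actual implementation:

import torch
import tensorrt as trt

# Stand-in dtype mapping, for this sketch only
_TORCH_TO_TRT = {
    torch.float32: trt.float32,
    torch.float16: trt.float16,
    torch.int32: trt.int32,
    torch.bool: trt.bool,
}

def resolve_output_dtype(i, current_dtype, output_bool, output_dtypes, output_fp16):
    if output_bool:
        return trt.bool
    if output_dtypes is not None:
        return _TORCH_TO_TRT[output_dtypes[i]]
    if output_fp16 and current_dtype == trt.float32:
        return trt.float16
    return current_dtype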

toolchains/ci_workspaces/WORKSPACE.x86_64.release.rhel

+4 −4

@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "c8407ae3462c344ae3814e82023e22ece759ebe75023f35bdf62e9c0a7e79035",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "76f983bd6d784cc0a95c679034d297abe36911c16b2188498b13a9028177e28e",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
 )

 ####################################################################################

toolchains/ci_workspaces/WORKSPACE.x86_64.release.ubuntu

+4 −4

@@ -56,17 +56,17 @@ new_local_repository(
 http_archive(
     name = "libtorch",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "1a526a9cd19c1015674d26921dbb94bcd2d632a6f9c431a21c43f4e24768d834",
+    sha256 = "c8407ae3462c344ae3814e82023e22ece759ebe75023f35bdf62e9c0a7e79035",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
 )

 http_archive(
     name = "libtorch_pre_cxx11_abi",
     build_file = "@//third_party/libtorch:BUILD",
-    sha256 = "60c5912a5085a6a7073b3804b10d41d6cc054693bbeb7a45e0247050c2837bac",
+    sha256 = "76f983bd6d784cc0a95c679034d297abe36911c16b2188498b13a9028177e28e",
     strip_prefix = "libtorch",
-    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230419%2Bcu118.zip"],
+    urls = ["https://download.pytorch.org/libtorch/nightly/cu118/libtorch-shared-with-deps-2.1.0.dev20230601%2Bcu118.zip"],
 )

 ####################################################################################
