
Commit 2a0d1c8

refactor(//py)!: Kwargs updates and support for shifting internal apis
BREAKING CHANGE: This commit changes the API from a dictionary of arguments to a set of kwargs. You can port forward using:

```py
trtorch.compile(mod, **spec)
```

Also, in preparation for partial compilation being enabled by default, settings related to torch fallback have been moved to the top level. Instead of:

```py
"torch_fallback": {
    "enabled": True,
    "min_block_size": 3,
    "forced_fallback_ops": ["aten::add"],
    "forced_fallback_mods": ["MySubModule"]
}
```

there are now new settings:

```py
require_full_compilation=False,
min_block_size=3,
torch_executed_ops=["aten::add"],
torch_executed_modules=["MySubModule"]
```

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 748ecf3 commit 2a0d1c8
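A sketch of the migration described above, assuming `mod` is an existing scripted module and that the former `torch_fallback` keys have been lifted to the top level (`MyModel` and all values are illustrative, not from this commit):

```py
import torch
import trtorch

# `mod` is assumed to be a scripted module (illustrative).
mod = torch.jit.script(MyModel().eval())

# Old style (before this commit): one dictionary of settings.
# trt_mod = trtorch.compile(mod, spec)

# New style: the same settings passed as keyword arguments,
# with the former "torch_fallback" keys moved to the top level.
trt_mod = trtorch.compile(
    mod,
    inputs=[trtorch.Input((1, 3, 224, 224))],
    require_full_compilation=False,
    min_block_size=3,
    torch_executed_ops=["aten::add"],
    torch_executed_modules=["MySubModule"],
)
```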

File tree

8 files changed: +140 −109 lines changed

docsrc/requirements.txt

+1 −1

```diff
@@ -2,5 +2,5 @@ sphinx==3.1.2
 breathe==4.19.2
 exhale
 sphinx_rtd_theme==0.4.3
-sphinx-material==0.0.30
+sphinx-material==0.0.35
 nbsphinx==0.8.6
```

py/trtorch/Device.py

+5

```diff
@@ -105,6 +105,11 @@ def _from_torch_device(cls, torch_dev: torch.device):
         gpu_id = torch_dev.index
         return cls(gpu_id=gpu_id)
 
+    @classmethod
+    def _current_device(cls):
+        dev = trtorch._C._get_current_device()
+        return cls(gpu_id=dev.gpu_id)
+
     @staticmethod
     def _parse_device_str(s):
         s = s.lower()
```
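For context, a minimal sketch (not part of the diff) of how this new helper might be exercised; it assumes a CUDA-capable runtime, since `_get_current_device` queries the active GPU through the C extension:

```py
import trtorch

# Resolve the currently active GPU as a trtorch.Device.
# This is the value the reworked kwargs API uses as its default `device`.
dev = trtorch.Device._current_device()
```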

py/trtorch/_compile_spec.py

+71 −53

```diff
@@ -4,6 +4,7 @@
 from trtorch import _types
 from trtorch.Input import Input
 from trtorch.Device import Device
+from trtorch._types import EngineCapability
 
 import warnings
 
@@ -246,63 +247,80 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     return info
 
 
-def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
-    """
-    Utility to create a formatted spec dictionary for using the PyTorch TensorRT backend
-
-    Args:
-        compile_spec (dict): Compilation settings including operating precision, target device, etc.
-            One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
-            to the graph as well as expected types and formats for those inputs. All other keys are optional.
-            Entries for each method to be compiled.
-
-            Note: Partial compilation of TorchScript modules is not supported through the PyTorch TensorRT backend.
-            If you need this feature, use trtorch.compile to compile your module. Usage of the resulting module is
-            as if you were using the TensorRT integration.
-
-            .. code-block:: py
-
-                CompileSpec = {
-                    "forward": trtorch.TensorRTCompileSpec({
-                        "inputs": [
-                            trtorch.Input((1, 3, 224, 224)), # Static input shape for input #1
-                            trtorch.Input(
-                                min_shape=(1, 3, 224, 224),
-                                opt_shape=(1, 3, 512, 512),
-                                max_shape=(1, 3, 1024, 1024),
-                                dtype=torch.int32,
-                                format=torch.channels_last
-                            ) # Dynamic input shape for input #2
-                        ],
-                        "device": {
-                            "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
-                            "gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
-                            "dla_core": 0, # (DLA only) Target dla core id to run engine
-                            "allow_gpu_fallback": False, # (DLA only) Allow layers unsupported on DLA to run on GPU
-                        },
-                        "enabled_precisions": {torch.half}, # Operating precision set to FP16
-                        "sparse_weights": False, # Enable sparsity for convolution and fully connected layers
-                        "disable_tf32": False, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
-                        "refit": False, # Enable refit
-                        "debug": False, # Enable debuggable engine
-                        "strict_types": False, # Kernels should strictly run in operating precision
-                        "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
-                        "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
-                        "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
-                        "workspace_size": 0, # Maximum size of workspace given to TensorRT
-                        "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
-                        "truncate_long_and_double": False, # Truncate long and double into int and float
-                    })
-                }
-
-    Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+def TensorRTCompileSpec(inputs=[],
+                        device=Device._current_device(),
+                        disable_tf32=False,
+                        sparse_weights=False,
+                        enabled_precisions=set(),
+                        refit=False,
+                        debug=False,
+                        strict_types=False,
+                        capability=EngineCapability.default,
+                        num_min_timing_iters=2,
+                        num_avg_timing_iters=1,
+                        workspace_size=0,
+                        max_batch_size=0,
+                        truncate_long_and_double=False,
+                        calibrator=None) -> torch.classes.tensorrt.CompileSpec:
+    """Utility to create a formatted spec dictionary for using the PyTorch TensorRT backend
+
+    Keyword Args:
+        inputs (List[Union(trtorch.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
             torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
-            to select device type.
-
-    Returns:
+            to select device type. ::
+
+                inputs=[
+                    trtorch.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    trtorch.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last
+                    ), # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 244)) # Use an example tensor and let trtorch infer settings
+                ]
+
+        device (Union(trtorch.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+
+                device=trtorch.Device("dla:1", allow_gpu_fallback=True)
+
+        disable_tf32 (bool): Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
+        sparse_weights (bool): Enable sparsity for convolution and fully connected layers
+        enabled_precisions (Set(Union(torch.dtype, trtorch.dtype))): The set of datatypes that TensorRT can use when selecting kernels
+        refit (bool): Enable refitting
+        debug (bool): Enable debuggable engine
+        strict_types (bool): Kernels should strictly run in a particular operating precision. Enabled precision should only have one type in the set
+        capability (trtorch.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
+        num_min_timing_iters (int): Number of minimization timing iterations used to select kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        workspace_size (int): Maximum size of workspace given to TensorRT
+        max_batch_size (int): Maximum batch size (must be >= 1 to be set, 0 means not set)
+        truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
+        calibrator (Union(trtorch._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 calibration
+
+    Returns:
         torch.classes.tensorrt.CompileSpec: List of methods and formatted spec objects to be provided to ``torch._C._jit_to_tensorrt``
     """
 
+    compile_spec = {
+        "inputs": inputs,
+        "device": device,
+        "disable_tf32": disable_tf32, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
+        "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers
+        "enabled_precisions": enabled_precisions, # Enabling FP16 kernels
+        "refit": refit, # Enable refit
+        "debug": debug, # Enable debuggable engine
+        "strict_types": strict_types, # Kernels should strictly run in operating precision
+        "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels
+        "num_min_timing_iters": num_min_timing_iters, # Number of minimization timing iterations used to select kernels
+        "num_avg_timing_iters": num_avg_timing_iters, # Number of averaging timing iterations used to select kernels
+        "workspace_size": workspace_size, # Maximum size of workspace given to TensorRT
+        "max_batch_size": max_batch_size, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+        "calibrator": calibrator,
+        "truncate_long_and_double": truncate_long_and_double
+    }
+
     parsed_spec = _parse_compile_spec(compile_spec)
 
     backend_spec = torch.classes.tensorrt.CompileSpec()
```
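For reference, a hedged sketch (not part of the diff) of calling the reworked function through the new kwargs interface; the `"forward"` method key and the settings below are illustrative values, not mandated by this commit:

```py
import torch
import trtorch

# One spec entry per method to be compiled, now built via kwargs
# instead of a nested settings dictionary.
spec = {
    "forward": trtorch.TensorRTCompileSpec(
        inputs=[trtorch.Input((1, 3, 224, 224))],
        enabled_precisions={torch.half},  # allow FP16 kernels
    )
}
```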
