Commit fcbf9e7

chore: rebase with main
2 parents 46cc402 + 574368c commit fcbf9e7

13 files changed, +394 -331 lines changed


core/runtime/TRTEngine.cpp

+1 -1
@@ -266,7 +266,7 @@ std::string TRTEngine::to_str() const {
           exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
       << std::endl;
   }
-  ss << " }" << std::endl;
+  ss << " ]" << std::endl;
   ss << " Device: " << device_info << std::endl;
   ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
   // clang-format on

docsrc/py_api/dynamo.rst

+2
@@ -22,6 +22,8 @@ Functions

 .. autofunction:: export

+.. autofunction:: convert_module_to_trt_engine
+


 Classes

docsrc/user_guide/saving_models.rst

+35 -36
@@ -9,23 +9,22 @@ Saving models compiled with Torch-TensorRT
     :undoc-members:
     :show-inheritance:

-Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
+Saving models compiled with Torch-TensorRT can be done using `torch_tensorrt.save` API.

 Dynamo IR
 -------------

-The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default.
-In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
-The `output_format` can take the following options
+The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
+We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) formats by
+specifying the `output_format` flag. Here are the options `output_format` will accept

-* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
-* `torchscript` (or) `ts` : This returns a TorchScript module
-* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.
+* `exported_program` : This is the default. We perform transformations on the graphmodule first and use `torch.export.save` to save the module.
+* `torchscript` : We trace the graphmodule via `torch.jit.trace` and save it via `torch.jit.save`.

-a) Torchscript
+a) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save
+Here's an example usage

 .. code-block:: python

@@ -34,50 +33,32 @@ If you set the `output_format="torchscript"`, this will return a `ScriptModule`

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_ts is a torch.jit.ScriptModule object
-    trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="torchscript")
-    torch.jit.save(trt_ts, "trt_model.ts")
+    # trt_ep is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
+    torchtrt.save(trt_gm, "trt.ep", inputs=inputs)

     # Later, you can load it and run inference
-    model = torch.jit.load("trt_model.ts").cuda()
+    model = torch.export.load("trt.ep").module()
     model(*inputs)

-b) ExportedProgram
+b) Torchscript
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X is the default return type of Torch-TensorRT compilation.
-
 .. code-block:: python

     import torch
     import torch_tensorrt

     model = MyModel().eval().cuda()
     inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_ep is a torch.export.ExportedProgram object
-    trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs)
-    torch.export.save(trt_ep, "trt_model.ep")
+    # trt_gm is a torch.fx.GraphModule object
+    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
+    torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)

     # Later, you can load it and run inference
-    model = torch.export.load("trt_model.ep")
+    model = torch.jit.load("trt.ts").cuda()
     model(*inputs)

-c) GraphModule
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
-Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into a Torchscript modules or
-exported into `ExportedProgram` objects
-
-.. code-block:: python
-
-    import torch
-    import torch_tensorrt
-
-    model = MyModel().eval().cuda()
-    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
-    # trt_gm is a torch.fx.GraphModule object
-    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="graph_module")

 Torchscript IR
 -------------
@@ -99,3 +80,21 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
     model = torch.jit.load("trt_model.ts").cuda()
     model(*inputs)

+
+Loading the models
+--------------------
+
+We can load torchscript or exported_program models using `torch.jit.load` and `torch.export.load` APIs from PyTorch directly.
+Alternatively, we provide a light wrapper `torch_tensorrt.load(file_path)` which can load either of the above model types.
+
+Here's an example usage
+
+.. code-block:: python
+
+    import torch
+    import torch_tensorrt
+
+    # file_path can be trt.ep or trt.ts file obtained via saving the model (refer to the above section)
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    model = torch_tensorrt.load(<file_path>).module()
+    model(*inputs)
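
Review note on the loading example above: `torch_tensorrt.load` returns whatever the first successful loader produced, so the result is either a TorchScript module (callable as-is) or an `ExportedProgram` (which needs `.module()` first). A minimal sketch of a caller-side helper that handles both cases follows. The name `as_callable` and the two stand-in classes are hypothetical and exist only so the snippet runs without PyTorch; real code would more safely check `isinstance(loaded, torch.export.ExportedProgram)`.

```python
from typing import Any, Callable


def as_callable(loaded: Any) -> Callable[..., Any]:
    """Return an object that can be called directly on the inputs.

    Hypothetical helper (not part of torch_tensorrt). Objects that expose a
    callable ``module`` attribute, as torch.export.ExportedProgram does, are
    unwrapped; anything else (e.g. a TorchScript module) is returned as-is.
    Caveat: duck typing would misfire on a module that happens to own a
    submodule named ``module``; an isinstance check avoids that.
    """
    unwrap = getattr(loaded, "module", None)
    if callable(unwrap):
        return unwrap()
    return loaded


class FakeExportedProgram:
    """Stand-in for an ExportedProgram: must be unwrapped via .module()."""

    def module(self) -> Callable[[int], int]:
        return lambda x: x + 1


class FakeScriptModule:
    """Stand-in for a TorchScript module: already callable."""

    def __call__(self, x: int) -> int:
        return x * 2


ep_result = as_callable(FakeExportedProgram())(1)
ts_result = as_callable(FakeScriptModule())(3)
```

With the real API, the same idea would read `model = as_callable(torch_tensorrt.load(file_path))` followed by `model(*inputs)`.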

py/torch_tensorrt/_compile.py

+118 -8
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import collections.abc
 import logging
 from enum import Enum
 from typing import Any, Callable, List, Optional, Sequence, Set
@@ -32,10 +33,7 @@

 logger = logging.getLogger(__name__)

-__all__ = [
-    "compile",
-    "convert_method_to_trt_engine",
-]
+__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]


 def _non_fx_input_interface(
@@ -240,8 +238,6 @@ def compile(
         return compiled_fx_module
     elif target_ir == _IRType.dynamo:
         # Prepare torch and torchtrt inputs
-        import collections.abc
-
         from torch_tensorrt.dynamo.utils import prepare_inputs

         if not isinstance(input_list, collections.abc.Sequence):
@@ -345,10 +341,19 @@ def convert_method_to_trt_engine(
             "convert_method_to_trt_engine call is not supported for ir=fx"
         )
     elif target_ir == _IRType.dynamo:
+        # Prepare torch and torchtrt inputs
+        from torch_tensorrt.dynamo.utils import prepare_inputs
+
+        if not isinstance(inputs, collections.abc.Sequence):
+            inputs = [inputs]
+
+        # Export the module
+        torchtrt_inputs = prepare_inputs(inputs)
+        exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
+
         return dynamo_convert_module_to_trt_engine(  # type: ignore[no-any-return]
-            module,
+            exp_program,
             inputs=inputs,
-            method_name=method_name,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
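
Review note on the hunk above: the new dynamo branch wraps a lone input in a list before tracing. A small standalone sketch of that normalization step is shown here; the function name `as_input_sequence` is hypothetical, and plain Python values stand in for tensors so the snippet runs without PyTorch.

```python
import collections.abc
from typing import Any, Sequence


def as_input_sequence(inputs: Any) -> Sequence[Any]:
    """Wrap a single input in a list; pass real sequences through unchanged.

    Mirrors the isinstance(..., collections.abc.Sequence) check in the diff.
    Note that str, list, and tuple all count as Sequence, so they are
    returned as-is, while scalars and other non-sequence objects are wrapped.
    """
    if not isinstance(inputs, collections.abc.Sequence):
        return [inputs]
    return inputs


single = as_input_sequence(5)          # a lone value gets wrapped
listed = as_input_sequence([1, 2])     # a list is returned unchanged
tupled = as_input_sequence((1,))       # a tuple is returned unchanged
```

One consequence of this check is that generators and sets are not `Sequence` instances, so they would be wrapped whole rather than iterated.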
@@ -358,3 +363,108 @@ def convert_method_to_trt_engine(
         )
     else:
         raise RuntimeError("Module is an unknown format or the ir requested is unknown")
+
+
+def load(file_path: str = "") -> Any:
+    """
+    Load either a Torchscript model or ExportedProgram. Autodetect the type using
+    try, except
+    """
+    try:
+        logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
+        ts_module = torch.jit.load(file_path)
+        return ts_module
+    except Exception:
+        logger.info(
+            f"Loading the provided file {file_path} via torch.jit.load() failed with the following error",
+            exc_info=True,
+        )
+        pass
+
+    try:
+        logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
+        exp_program = torch.export.load(file_path)
+        return exp_program
+    except Exception:
+        logger.info(
+            f"Loading the provided file {file_path} via torch.export.load() failed with the following error",
+            exc_info=True,
+        )
+        raise ValueError(
+            f"The file {file_path} doesn't correspond to a valid Torchscript module or ExportedProgram. Please verify the file path."
+        )
+
+
+def save(
+    module: Any,
+    file_path: str = "",
+    *,
+    output_format: str = "exported_program",
+    inputs: Optional[Sequence[torch.Tensor]] = None,
+    retrace: bool = False,
+) -> None:
+    """
+    Save the model to disk in the specified output format.
+    Arguments:
+        module : Compiled Torch-TensorRT module (Options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
+        inputs (torch.Tensor): Torch input tensors
+        output_format: Format to save the model. Options include exported_program | torchscript.
+        retrace: When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
+                This flag is experimental for now.
+    """
+    module_type = _parse_module_type(module)
+    accepted_formats = {"exported_program", "torchscript"}
+    if inputs is not None and not all(
+        isinstance(input, torch.Tensor) for input in inputs
+    ):
+        raise ValueError(
+            "Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
+        )
+    if output_format not in accepted_formats:
+        raise ValueError(
+            f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
+        )
+    if not file_path:
+        raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+    if module_type == _ModuleType.nn:
+        raise ValueError(
+            "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
+        )
+    elif module_type == _ModuleType.ts:
+        if output_format == "exported_program":
+            raise ValueError(
+                "Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
+            )
+        else:
+            torch.jit.save(module, file_path)
+    elif module_type == _ModuleType.ep:
+        if output_format == "torchscript":
+            raise ValueError(
+                "Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
+            )
+        else:
+            torch.export.save(module, file_path)
+    elif module_type == _ModuleType.fx:
+        if inputs is None:
+            raise ValueError(
+                "Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
+            )
+        # The module type is torch.fx.GraphModule
+        if output_format == "torchscript":
+            module_ts = torch.jit.trace(module, inputs)
+            torch.jit.save(module_ts, file_path)
+        else:
+            if not retrace:
+                from torch_tensorrt.dynamo._exporter import export
+
+                exp_program = export(module, inputs)
+                torch.export.save(exp_program, file_path)
+            else:
+                from torch._higher_order_ops.torchbind import enable_torchbind_tracing
+
+                with enable_torchbind_tracing():
+                    exp_program = torch.export.export(
+                        module, tuple(inputs), strict=False
+                    )
+                    torch.export.save(exp_program, file_path)
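
Review note on `save` above: the branching depends on three things (module type, `output_format`, and whether inputs and `retrace` were given), which can be hard to follow in diff form. The sketch below restates the same decision table as a pure function that only names the path taken. `plan_save`, its string module-type tags (`"nn"`, `"ts"`, `"ep"`, `"fx"`), and the returned labels are hypothetical illustrations, not part of the commit; no PyTorch call is made.

```python
def plan_save(
    module_type: str,
    output_format: str = "exported_program",
    has_inputs: bool = False,
    retrace: bool = False,
) -> str:
    """Name the save path the diff's `save` would take, or raise ValueError.

    module_type is an illustrative tag: "nn", "ts", "ep", or "fx".
    """
    if output_format not in {"exported_program", "torchscript"}:
        raise ValueError(f"unsupported output_format: {output_format}")
    if module_type == "nn":
        raise ValueError("saving a plain nn.Module is not supported")
    if module_type == "ts":
        # A ScriptModule can only be written back out as TorchScript.
        if output_format == "exported_program":
            raise ValueError("ScriptModule cannot be saved as exported_program")
        return "torch.jit.save"
    if module_type == "ep":
        # An ExportedProgram can only be written via torch.export.save.
        if output_format == "torchscript":
            raise ValueError("ExportedProgram cannot be saved as torchscript")
        return "torch.export.save"
    if module_type == "fx":
        # A GraphModule must be traced or exported first, which needs inputs.
        if not has_inputs:
            raise ValueError("GraphModule requires inputs to trace and save")
        if output_format == "torchscript":
            return "torch.jit.trace -> torch.jit.save"
        if retrace:
            return "torch.export.export(strict=False) -> torch.export.save"
        return "torch_tensorrt.dynamo._exporter.export -> torch.export.save"
    raise ValueError(f"unknown module type: {module_type}")


default_fx_path = plan_save("fx", has_inputs=True)
retraced_fx_path = plan_save("fx", has_inputs=True, retrace=True)
```

Reading the table this way makes one property of the commit explicit: only `torch.fx.GraphModule` inputs can be converted between formats at save time, while already-serializable modules must be saved in their own format.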
