Skip to content

Commit 574368c

Browse files
authored
chore: cherry pick of save API (#2719)
1 parent ad0d786 commit 574368c

File tree

11 files changed

+356
-295
lines changed

11 files changed

+356
-295
lines changed

core/runtime/TRTEngine.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ std::string TRTEngine::to_str() const {
241241
exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
242242
<< std::endl;
243243
}
244-
ss << " }" << std::endl;
244+
ss << " ]" << std::endl;
245245
ss << " Device: " << device_info << std::endl;
246246
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
247247
// clang-format on

docsrc/user_guide/saving_models.rst

+35-36
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,22 @@ Saving models compiled with Torch-TensorRT
99
:undoc-members:
1010
:show-inheritance:
1111

12-
Saving models compiled with Torch-TensorRT varies slightly with the `ir` that has been used for compilation.
12+
Saving models compiled with Torch-TensorRT can be done using `torch_tensorrt.save` API.
1313

1414
Dynamo IR
1515
-------------
1616

17-
The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.export.ExportedProgram` object by default.
18-
In addition, we provide a new parameter `output_format` in the `CompilationSetting` object provided before compilation.
19-
The `output_format` can take the following options
17+
The output type of `ir=dynamo` compilation of Torch-TensorRT is `torch.fx.GraphModule` object by default.
18+
We can save this object in either `TorchScript` (`torch.jit.ScriptModule`) or `ExportedProgram` (`torch.export.ExportedProgram`) formats by
19+
specifying the `output_format` flag. Here are the options `output_format` will accept
2020

21-
* `exported_program` (or) `ep` : This is the default. Returns an ExportedProgram
22-
* `torchscript` (or) `ts` : This returns a TorchScript module
23-
* `graph_module` (or) `fx` : This returns a torch.fx.GraphModule which can be traced into Torchscript to save to disk.
21+
* `exported_program` : This is the default. We perform transformations on the graphmodule first and use `torch.export.save` to save the module.
22+
* `torchscript` : We trace the graphmodule via `torch.jit.trace` and save it via `torch.jit.save`.
2423

25-
a) Torchscript
24+
a) ExportedProgram
2625
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2726

28-
If you set the `output_format="torchscript"`, this will return a `ScriptModule` which can be serialized via torch.jit.save
27+
Here's an example usage
2928

3029
.. code-block:: python
3130
@@ -34,50 +33,32 @@ If you set the `output_format="torchscript"`, this will return a `ScriptModule`
3433
3534
model = MyModel().eval().cuda()
3635
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
37-
# trt_ts is a torch.jit.ScriptModule object
38-
trt_ts = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="torchscript")
39-
torch.jit.save(trt_ts, "trt_model.ts")
36+
# trt_ep is a torch.fx.GraphModule object
37+
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
38+
torchtrt.save(trt_gm, "trt.ep", inputs=inputs)
4039
4140
# Later, you can load it and run inference
42-
model = torch.jit.load("trt_model.ts").cuda()
41+
model = torch.export.load("trt.ep").module()
4342
model(*inputs)
4443
45-
b) ExportedProgram
44+
b) Torchscript
4645
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4746

48-
`torch.export.ExportedProgram`, a new format introduced in Pytorch 2.X is the default return type of Torch-TensorRT compilation.
49-
5047
.. code-block:: python
5148
5249
import torch
5350
import torch_tensorrt
5451
5552
model = MyModel().eval().cuda()
5653
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
57-
# trt_ep is a torch.export.ExportedProgram object
58-
trt_ep = torch_tensorrt.compile(model, ir="dynamo", inputs)
59-
torch.export.save(trt_ep, "trt_model.ep")
54+
# trt_gm is a torch.fx.GraphModule object
55+
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs)
56+
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
6057
6158
# Later, you can load it and run inference
62-
model = torch.export.load("trt_model.ep")
59+
model = torch.jit.load("trt.ts").cuda()
6360
model(*inputs)
6461
65-
c) GraphModule
66-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
67-
68-
We can also return a `torch.fx.GraphModule` object as the output of Torch-TensorRT compilation by setting `output_format="graph_module"`.
69-
Internally, partitioning, lowering, conversion phases operate using GraphModule objects. These can be either traced into a Torchscript modules or
70-
exported into `ExportedProgram` objects
71-
72-
.. code-block:: python
73-
74-
import torch
75-
import torch_tensorrt
76-
77-
model = MyModel().eval().cuda()
78-
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
79-
# trt_gm is a torch.fx.GraphModule object
80-
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs, output_format="graph_module")
8162
8263
Torchscript IR
8364
-------------
@@ -99,3 +80,21 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well.
9980
model = torch.jit.load("trt_model.ts").cuda()
10081
model(*inputs)
10182
83+
84+
Loading the models
85+
--------------------
86+
87+
We can load torchscript or exported_program models using `torch.jit.load` and `torch.export.load` APIs from PyTorch directly.
88+
Alternatively, we provide a light wrapper `torch_tensorrt.load(file_path)` which can load either of the above model types.
89+
90+
Here's an example usage
91+
92+
.. code-block:: python
93+
94+
import torch
95+
import torch_tensorrt
96+
97+
# file_path can be trt.ep or trt.ts file obtained via saving the model (refer to the above section)
98+
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
99+
model = torch_tensorrt.load(<file_path>).module()
100+
model(*inputs)

py/torch_tensorrt/_compile.py

+106-4
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@
3333

3434
logger = logging.getLogger(__name__)
3535

36-
__all__ = [
37-
"compile",
38-
"convert_method_to_trt_engine",
39-
]
36+
__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]
4037

4138

4239
def _non_fx_input_interface(
@@ -366,3 +363,108 @@ def convert_method_to_trt_engine(
366363
)
367364
else:
368365
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
366+
367+
368+
def load(file_path: str = "") -> Any:
369+
"""
370+
Load either a Torchscript model or ExportedProgram. Autodetect the type using
371+
try, except
372+
"""
373+
try:
374+
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
375+
ts_module = torch.jit.load(file_path)
376+
return ts_module
377+
except Exception:
378+
logger.info(
379+
f"Loading the provided file {file_path} via torch.jit.load() failed with the following error",
380+
exc_info=True,
381+
)
382+
pass
383+
384+
try:
385+
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
386+
exp_program = torch.export.load(file_path)
387+
return exp_program
388+
except Exception:
389+
logger.info(
390+
f"Loading the provided file {file_path} via torch.export.load() failed with the following error",
391+
exc_info=True,
392+
)
393+
raise ValueError(
394+
f"The file {file_path} doesn't correspond to a valid Torchscript module or ExportedProgram. Please verify the file path."
395+
)
396+
397+
398+
def save(
399+
module: Any,
400+
file_path: str = "",
401+
*,
402+
output_format: str = "exported_program",
403+
inputs: Optional[Sequence[torch.Tensor]] = None,
404+
retrace: bool = False,
405+
) -> None:
406+
"""
407+
Save the model to disk in the specified output format.
408+
Arguments:
409+
module : Compiled Torch-TensorRT module (Options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
410+
inputs (torch.Tensor): Torch input tensors
411+
output_format: Format to save the model. Options include exported_program | torchscript.
412+
retrace: When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
413+
This flag is experimental for now.
414+
"""
415+
module_type = _parse_module_type(module)
416+
accepted_formats = {"exported_program", "torchscript"}
417+
if inputs is not None and not all(
418+
isinstance(input, torch.Tensor) for input in inputs
419+
):
420+
raise ValueError(
421+
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
422+
)
423+
if output_format not in accepted_formats:
424+
raise ValueError(
425+
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
426+
)
427+
if not file_path:
428+
raise ValueError("File path cannot be empty. Please provide a valid file path")
429+
430+
if module_type == _ModuleType.nn:
431+
raise ValueError(
432+
"Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
433+
)
434+
elif module_type == _ModuleType.ts:
435+
if output_format == "exported_program":
436+
raise ValueError(
437+
"Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
438+
)
439+
else:
440+
torch.jit.save(module, file_path)
441+
elif module_type == _ModuleType.ep:
442+
if output_format == "torchscript":
443+
raise ValueError(
444+
"Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
445+
)
446+
else:
447+
torch.export.save(module, file_path)
448+
elif module_type == _ModuleType.fx:
449+
if inputs is None:
450+
raise ValueError(
451+
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
452+
)
453+
# The module type is torch.fx.GraphModule
454+
if output_format == "torchscript":
455+
module_ts = torch.jit.trace(module, inputs)
456+
torch.jit.save(module_ts, file_path)
457+
else:
458+
if not retrace:
459+
from torch_tensorrt.dynamo._exporter import export
460+
461+
exp_program = export(module, inputs)
462+
torch.export.save(exp_program, file_path)
463+
else:
464+
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
465+
466+
with enable_torchbind_tracing():
467+
exp_program = torch.export.export(
468+
module, tuple(inputs), strict=False
469+
)
470+
torch.export.save(exp_program, file_path)

py/torch_tensorrt/dynamo/_compiler.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
dryrun_stats_display,
1818
parse_non_trt_nodes,
1919
)
20-
from torch_tensorrt.dynamo._exporter import export
2120
from torch_tensorrt.dynamo.conversion import (
2221
CompilationSettings,
2322
UnsupportedOperatorException,
@@ -73,9 +72,8 @@ def compile(
7372
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
7473
dryrun: bool = _defaults.DRYRUN,
7574
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
76-
output_format: str = _defaults.OUTPUT_FORMAT,
7775
**kwargs: Any,
78-
) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
76+
) -> torch.fx.GraphModule:
7977
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
8078
8179
Takes a existing TorchScript module and a set of settings to configure the compiler
@@ -132,7 +130,6 @@ def compile(
132130
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
133131
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
134132
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
135-
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
136133
**kwargs: Any,
137134
Returns:
138135
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -202,14 +199,12 @@ def compile(
202199
"dla_global_dram_size": dla_global_dram_size,
203200
"dryrun": dryrun,
204201
"hardware_compatible": hardware_compatible,
205-
"output_format": output_format,
206202
}
207203

208204
settings = CompilationSettings(**compilation_options)
209205
logger.info("Compilation Settings: %s\n", settings)
210206
trt_gm = compile_module(gm, inputs, settings)
211-
trt_result = export(trt_gm, torch_inputs, output_format)
212-
return trt_result
207+
return trt_gm
213208

214209

215210
def compile_module(

py/torch_tensorrt/dynamo/_defaults.py

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
REQUIRE_FULL_COMPILATION = False
2727
DRYRUN = False
2828
HARDWARE_COMPATIBLE = False
29-
OUTPUT_FORMAT = "exported_program"
3029
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}
3130

3231

0 commit comments

Comments
 (0)