4
4
from functools import partial
5
5
import torch ._dynamo as td
6
6
7
- from torch_tensorrt .dynamo .torch_compile ._settings import CompilationSettings
8
- from torch_tensorrt .dynamo .torch_compile .lowering ._decompositions import (
7
+ from torch_tensorrt .dynamo .backend ._settings import CompilationSettings
8
+ from torch_tensorrt .dynamo .backend .lowering ._decompositions import (
9
9
get_decompositions ,
10
10
)
11
- from torch_tensorrt .dynamo .torch_compile .lowering ._partition import (
11
+ from torch_tensorrt .dynamo .backend .lowering ._partition import (
12
12
partition ,
13
13
get_submod_inputs ,
14
14
)
15
- from torch_tensorrt .dynamo .torch_compile .conversion import convert_module
15
+ from torch_tensorrt .dynamo .backend .conversion import convert_module
16
16
17
17
from torch ._dynamo .backends .common import fake_tensor_unsupported
18
18
19
19
from torch ._functorch .aot_autograd import aot_module_simplified , make_boxed_compiler
20
20
21
21
22
@td.register_backend(name="torch_tensorrt")
@fake_tensor_unsupported
def torch_tensorrt_backend(
    gm: torch.fx.GraphModule,
    sample_inputs: Sequence[torch.Tensor],
    settings: CompilationSettings = CompilationSettings(),
):
    """Entry point for the ``torch_tensorrt`` Dynamo backend.

    Simply forwards all arguments to the default backend implementation,
    ``aot_torch_tensorrt_aten_backend``.

    Args:
        gm: Traced FX GraphModule to compile.
        sample_inputs: Example tensors used to drive compilation.
        settings: Compilation options for the backend.

    Returns:
        The result of the delegated backend call (the compiled callable).
    """
    # NOTE(review): the default ``CompilationSettings()`` instance is created
    # once at definition time and shared across calls — confirm it is
    # immutable/never mutated by the pipeline.
    # The AOT ATen backend is currently the default (and only) delegate.
    return aot_torch_tensorrt_aten_backend(
        gm=gm, sample_inputs=sample_inputs, settings=settings
    )
32
+
33
+
34
+ @td .register_backend (name = "aot_torch_tensorrt_aten" )
35
+ @fake_tensor_unsupported
36
+ def aot_torch_tensorrt_aten_backend (
37
+ gm : torch .fx .GraphModule ,
26
38
sample_inputs : Sequence [torch .Tensor ],
27
39
settings : CompilationSettings = CompilationSettings (),
28
40
):
29
41
custom_backend = partial (
30
- fx_dynamo_backend ,
42
+ _pretraced_backend ,
31
43
settings = settings ,
32
44
)
33
45
@@ -40,14 +52,12 @@ def tensorrt_backend(
40
52
)
41
53
42
54
43
- @td .register_backend (name = "fx_tensorrt" )
44
- @fake_tensor_unsupported
45
- def fx_dynamo_backend (
55
+ def _pretraced_backend (
46
56
gm : torch .fx .GraphModule ,
47
- example_inputs : Sequence [torch .Tensor ],
57
+ sample_inputs : Sequence [torch .Tensor ],
48
58
settings : CompilationSettings = CompilationSettings (),
49
59
):
50
- """Helper function to manage translation of FX module to TRT engines
60
+ """Helper function to manage translation of traced FX module to TRT engines
51
61
52
62
Args:
53
63
module: FX GraphModule to convert
@@ -57,9 +67,9 @@ def fx_dynamo_backend(
57
67
Compiled FX GraphModule
58
68
"""
59
69
try :
60
- trt_compiled = compile_module (
70
+ trt_compiled = _compile_module (
61
71
gm ,
62
- example_inputs ,
72
+ sample_inputs ,
63
73
settings = settings ,
64
74
)
65
75
return trt_compiled
@@ -72,12 +82,12 @@ def fx_dynamo_backend(
72
82
return gm .forward
73
83
74
84
75
- def compile_module (
85
+ def _compile_module (
76
86
gm : torch .fx .GraphModule ,
77
- example_inputs : Sequence [torch .Tensor ],
87
+ sample_inputs : Sequence [torch .Tensor ],
78
88
settings : CompilationSettings = CompilationSettings (),
79
89
) -> torch .fx .GraphModule :
80
- """Compile an FX module
90
+ """Compile a traced FX module
81
91
82
92
Includes: Partitioning + Conversion Phases
83
93
@@ -100,7 +110,7 @@ def compile_module(
100
110
101
111
# Get submodule inputs
102
112
submodule_inputs = get_submod_inputs (
103
- partitioned_module , submodule , example_inputs
113
+ partitioned_module , submodule , sample_inputs
104
114
)
105
115
106
116
# Create TRT Module from submodule
0 commit comments