@@ -13,7 +13,13 @@
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_registry import CallingConvention
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_node_name,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
@@ -46,6 +52,7 @@ def __init__(
         input_specs: List[Input],
         logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[List[torch.dtype]] = None,
+        compilation_settings: CompilationSettings = CompilationSettings(),
     ):
         super().__init__(module)
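The interpreter now accepts the user's compile options directly at construction time. A minimal usage sketch, assuming this is the TRTInterpreter constructor from this file; `module` is a placeholder fx.GraphModule, and any CompilationSettings field overrides would go in its constructor call:

    # Hypothetical usage -- not shown in this diff.
    from torch_tensorrt._Input import Input
    from torch_tensorrt.dynamo._settings import CompilationSettings

    interpreter = TRTInterpreter(
        module,
        input_specs=[Input(shape=(1, 3, 224, 224))],
        compilation_settings=CompilationSettings(),  # defaults
    )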
@@ -59,7 +66,9 @@ def __init__(
         EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         flag |= EXPLICIT_BATCH

-        self.network = self.builder.create_network(flag)
+        self.ctx = ConversionContext(
+            self.builder.create_network(flag), compilation_settings
+        )

         missing_ops = self.validate_conversion()
         if missing_ops:
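ConversionContext itself is not shown in this diff; all the diff confirms is that the context is built from the network plus the settings and exposes the network as `.net`. A sketch consistent with that usage, assuming a plain dataclass:

    # Sketch only -- field names beyond `net` and `compilation_settings` (the
    # two constructor arguments used above) are not confirmed by this diff.
    from dataclasses import dataclass, field

    import tensorrt as trt

    from torch_tensorrt.dynamo._settings import CompilationSettings


    @dataclass
    class ConversionContext:
        """Pairs the TRT network being built with the active compile settings."""

        net: trt.INetworkDefinition
        compilation_settings: CompilationSettings = field(
            default_factory=CompilationSettings
        )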
@@ -95,14 +104,14 @@ def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()

         for node in self.module.graph.nodes:
-            if node.op == "call_function" and not CONVERTERS.get(node):
+            if node.op == "call_function" and CONVERTERS.get(node) is None:
                 missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}")
-            elif node.op == "call_method" and not CONVERTERS.get(node):
+            elif node.op == "call_method" and CONVERTERS.get(node) is None:
                 missing_converters.add(f"{node.op} torch.Tensor.{node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
-                if not CONVERTERS.get(node):
+                if CONVERTERS.get(node) is None:
                     missing_converters.add(f"{node.op} {torch.typename(submod_type)}")

         return missing_converters
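The switch from truthiness tests to explicit `is None` checks reflects the registry's new return type: `CONVERTERS.get(node)` now yields a packet rather than a bare callable, and `None` is the unambiguous "no converter registered" signal. The implied lookup contract, with type names as assumptions:

    # Assumed shape of the registry lookup after this change.
    from typing import Any, Callable, Optional, Tuple

    ConverterPacket = Tuple[Callable[..., Any], "CallingConvention"]

    class ConverterRegistry:
        def get(self, node) -> Optional[ConverterPacket]:
            """Return (converter, calling_convention), or None if unregistered."""
            ...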
@@ -221,7 +230,7 @@ def run(
         if tactic_sources is not None:
             builder_config.set_tactic_sources(tactic_sources=tactic_sources)

-        engine = self.builder.build_engine(self.network, builder_config)
+        engine = self.builder.build_engine(self.ctx.net, builder_config)
         assert engine

         serialized_cache = (
@@ -291,7 +300,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
                 f"Unable to access shape spec for input: {target} (got: {current_input})"
             )

-        return self.network.add_input(
+        return self.ctx.net.add_input(
             name=target,
             shape=tuple(shape),
             dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
@@ -303,30 +312,40 @@ def call_module(
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         submod_type = getattr(submod, "_base_class_origin", type(submod))
-        converter = CONVERTERS.get(self._cur_node)
+        converter_packet = CONVERTERS.get(self._cur_node)

-        if not converter:
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )

+        converter, calling_convention = converter_packet
+
         assert self._cur_node_name is not None
-        return converter(self.network, submod, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, submod, args, kwargs, self._cur_node_name)

     def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         # TODO: Why is this stateful? We should be able to take in the inputs
-        converter = CONVERTERS.get(self._cur_node)
-        if not converter:
+        converter_packet = CONVERTERS.get(self._cur_node)
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )

+        converter, calling_convention = converter_packet
+
         assert self._cur_node_name is not None
-        return converter(self.network, target, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, target, args, kwargs, self._cur_node_name)

     def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
         with _disable_current_modes():
-            from torch_tensorrt.fx.converters import to_numpy
+            from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy

             frozen_attr = self.fetch_attr(target)
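call_module, call_function, and call_method now share one dispatch pattern: unpack the packet, then hand LEGACY converters the bare network and everything else the full context. Distilled below; CallingConvention.CTX is an assumed name for the new-style convention, which the diff reaches only through the `else` branch:

    # Distilled dispatch pattern implied by the three call_* hooks above.
    from enum import Enum, auto

    class CallingConvention(Enum):
        LEGACY = auto()  # converter(network, target, args, kwargs, name)
        CTX = auto()     # assumed: converter(ctx, target, args, kwargs, name)

    def dispatch(converter_packet, ctx, target, args, kwargs, name):
        converter, convention = converter_packet
        if convention is CallingConvention.LEGACY:
            # Legacy FX converters predate ConversionContext and expect
            # the raw TensorRT network.
            return converter(ctx.net, target, args, kwargs, name)
        return converter(ctx, target, args, kwargs, name)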
@@ -341,15 +360,19 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:

     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
-        converter = CONVERTERS.get(self._cur_node)
+        converter_packet = CONVERTERS.get(self._cur_node)

-        if not converter:
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
+        converter, calling_convention = converter_packet

         assert self._cur_node_name is not None
-        return converter(self.network, target, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, target, args, kwargs, self._cur_node_name)

     def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         assert len(args) == 1
@@ -361,12 +384,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
             outputs = (args[0],)

         for output_idx in range(len(outputs)):
-            from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-
             output = outputs[output_idx]

             if not isinstance(output, trt.tensorrt.ITensor):
-                new_output = get_trt_tensor(self.network, output, target)
+                new_output = get_trt_tensor(self.ctx, output, target)
                 outputs = (
                     outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
                 )
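get_trt_tensor's first argument changes from the network to the context, matching the now module-level import that replaces the per-iteration one. Its role at this call site is to freeze any non-ITensor output into the network; a rough sketch of the constant case, assuming internals the diff does not show:

    # Conceptual sketch only -- not the real converter_utils implementation,
    # which handles more input types.
    import numpy as np
    import tensorrt as trt

    def get_trt_tensor_sketch(ctx, value, name):
        if isinstance(value, trt.ITensor):
            return value
        arr = np.asarray(value)
        layer = ctx.net.add_constant(arr.shape, arr)  # embed as a TRT constant
        layer.name = name
        return layer.get_output(0)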
@@ -400,7 +421,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
             output_bool = False
             name = f"output{i}"
             output.name = name
-            self.network.mark_output(output)
+            self.ctx.net.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
             elif self.output_dtypes is not None: