Commit 42e514b

feat: Add tensor type enforcement for converters (#2324)
1 parent: 5de208f

36 files changed: +951 -616 lines
py/torch_tensorrt/dynamo/conversion/_ConversionContext.py (new file)

+19
@@ -0,0 +1,19 @@
+from dataclasses import dataclass, field
+
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.fx.types import TRTNetwork
+
+
+@dataclass
+class ConversionContext:
+    """Class representing the context for conversion of a particular network
+
+    Args:
+        net: TensorRT Network being built
+        compilation_settings: Settings selected by the user for compilation
+    """
+
+    net: TRTNetwork
+    compilation_settings: CompilationSettings = field(
+        default_factory=CompilationSettings
+    )

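For orientation, a minimal usage sketch (not part of the diff) of how this new dataclass is constructed; the TRTInterpreter changes below do exactly this. The builder/logger setup is illustrative.

    import tensorrt as trt
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.conversion import ConversionContext

    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )

    # compilation_settings may be omitted: field(default_factory=CompilationSettings)
    # supplies a fresh default instance per ConversionContext, avoiding the
    # shared-mutable-default pitfall of dataclasses.
    ctx = ConversionContext(network)
    assert isinstance(ctx.compilation_settings, CompilationSettings)
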
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+42 -21
@@ -13,7 +13,13 @@
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_registry import CallingConvention
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_node_name,
+    get_trt_tensor,
+)
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

@@ -46,6 +52,7 @@ def __init__(
         input_specs: List[Input],
         logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[List[torch.dtype]] = None,
+        compilation_settings: CompilationSettings = CompilationSettings(),
     ):
         super().__init__(module)

@@ -59,7 +66,9 @@ def __init__(
         EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         flag |= EXPLICIT_BATCH
 
-        self.network = self.builder.create_network(flag)
+        self.ctx = ConversionContext(
+            self.builder.create_network(flag), compilation_settings
+        )
 
         missing_ops = self.validate_conversion()
         if missing_ops:
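With the network now wrapped in a ConversionContext, the interpreter takes the user's settings at construction time. A hedged sketch of the new entry point (argument values are illustrative; `gm` stands for a torch.fx.GraphModule produced earlier in the pipeline):

    from torch_tensorrt._Input import Input
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.conversion import TRTInterpreter

    interpreter = TRTInterpreter(
        gm,  # an fx.GraphModule lowered for conversion
        input_specs=[Input(shape=(1, 3, 224, 224))],
        compilation_settings=CompilationSettings(),  # the new keyword argument
    )

One caveat on the signature itself: `CompilationSettings()` as a default argument is evaluated once at definition time, so all calls that omit the keyword share one settings object. That is harmless while converters only read from it, but it is why ConversionContext uses default_factory instead.
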
@@ -95,14 +104,14 @@ def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()
 
         for node in self.module.graph.nodes:
-            if node.op == "call_function" and not CONVERTERS.get(node):
+            if node.op == "call_function" and CONVERTERS.get(node) is None:
                 missing_converters.add(f"{node.op} {_get_qualified_name(node.target)}")
-            elif node.op == "call_method" and not CONVERTERS.get(node):
+            elif node.op == "call_method" and CONVERTERS.get(node) is None:
                 missing_converters.add(f"{node.op} torch.Tensor.{node.target}")
             elif node.op == "call_module":
                 submod = self.fetch_attr(node.target)
                 submod_type = getattr(submod, "_base_class_origin", type(submod))
-                if not CONVERTERS.get(node):
+                if CONVERTERS.get(node) is None:
                     missing_converters.add(f"{node.op} {torch.typename(submod_type)}")
 
         return missing_converters
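The switch from truthiness (`not CONVERTERS.get(node)`) to an explicit `is None` matters because the registry no longer returns a bare callable: with this commit it returns a packet pairing the converter with its CallingConvention. A sketch of the contract these checks now assume (the names `packet` and `handle` are illustrative, not from the diff):

    from typing import Any, Callable, Optional, Tuple

    from torch_tensorrt.dynamo.conversion.converter_registry import CallingConvention

    # What CONVERTERS.get(node) is assumed to yield after this change:
    ConverterPacket = Tuple[Callable[..., Any], CallingConvention]

    def handle(packet: Optional[ConverterPacket]) -> None:
        if packet is None:
            # Nothing registered for this node's target: report it as missing.
            return
        converter, calling_convention = packet  # always unpackable when present

`is None` states the intent directly and cannot misfire if some future packet type happens to be falsy.
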
@@ -221,7 +230,7 @@ def run(
         if tactic_sources is not None:
             builder_config.set_tactic_sources(tactic_sources=tactic_sources)
 
-        engine = self.builder.build_engine(self.network, builder_config)
+        engine = self.builder.build_engine(self.ctx.net, builder_config)
         assert engine
 
         serialized_cache = (
@@ -291,7 +300,7 @@ def placeholder(self, target: str, args: Any, kwargs: Any) -> trt.ITensor:
                 f"Unable to access shape spec for input: {target} (got: {current_input})"
             )
 
-        return self.network.add_input(
+        return self.ctx.net.add_input(
             name=target,
             shape=tuple(shape),
             dtype=unified_dtype_converter(current_input.torch_dtype, Frameworks.TRT),
@@ -303,30 +312,40 @@ def call_module(
         assert isinstance(target, str)
         submod = self.fetch_attr(target)
         submod_type = getattr(submod, "_base_class_origin", type(submod))
-        converter = CONVERTERS.get(self._cur_node)
+        converter_packet = CONVERTERS.get(self._cur_node)
 
-        if not converter:
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )
 
+        converter, calling_convention = converter_packet
+
         assert self._cur_node_name is not None
-        return converter(self.network, submod, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, submod, args, kwargs, self._cur_node_name)
 
     def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         # TODO: Why is this stateful? We should be able to take in the inputs
-        converter = CONVERTERS.get(self._cur_node)
-        if not converter:
+        converter_packet = CONVERTERS.get(self._cur_node)
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )
 
+        converter, calling_convention = converter_packet
+
         assert self._cur_node_name is not None
-        return converter(self.network, target, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, target, args, kwargs, self._cur_node_name)
 
     def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
         with _disable_current_modes():
-            from torch_tensorrt.fx.converters import to_numpy
+            from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
 
             frozen_attr = self.fetch_attr(target)

@@ -341,15 +360,19 @@ def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
 
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
-        converter = CONVERTERS.get(self._cur_node)
+        converter_packet = CONVERTERS.get(self._cur_node)
 
-        if not converter:
+        if converter_packet is None:
             raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
+        converter, calling_convention = converter_packet
 
         assert self._cur_node_name is not None
-        return converter(self.network, target, args, kwargs, self._cur_node_name)
+        if calling_convention is CallingConvention.LEGACY:
+            return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
+        else:
+            return converter(self.ctx, target, args, kwargs, self._cur_node_name)
 
     def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         assert len(args) == 1
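All three call_* methods now dispatch on the packet's CallingConvention: legacy converters receive the raw TRT network, while context-aware converters receive the full ConversionContext (and with it compilation_settings, which is what enables per-converter tensor type enforcement). The two signatures, sketched for a hypothetical ReLU converter (the registration decorators live in converter_registry/aten_ops_converters and are not shown):

    import tensorrt as trt

    from torch_tensorrt.dynamo.conversion import ConversionContext

    def relu_legacy(net, target, args, kwargs, name):
        # CallingConvention.LEGACY: first parameter is the bare TRTNetwork.
        layer = net.add_activation(args[0], trt.ActivationType.RELU)
        layer.name = name
        return layer.get_output(0)

    def relu_ctx(ctx: ConversionContext, target, args, kwargs, name):
        # Context convention: the network lives at ctx.net, and the converter
        # can consult ctx.compilation_settings when validating tensor types.
        layer = ctx.net.add_activation(args[0], trt.ActivationType.RELU)
        layer.name = name
        return layer.get_output(0)
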
@@ -361,12 +384,10 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         outputs = (args[0],)
 
         for output_idx in range(len(outputs)):
-            from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-
             output = outputs[output_idx]
 
             if not isinstance(output, trt.tensorrt.ITensor):
-                new_output = get_trt_tensor(self.network, output, target)
+                new_output = get_trt_tensor(self.ctx, output, target)
                 outputs = (
                     outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
                 )
@@ -400,7 +421,7 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
             output_bool = False
             name = f"output{i}"
             output.name = name
-            self.network.mark_output(output)
+            self.ctx.net.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
             elif self.output_dtypes is not None:

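get_trt_tensor, now imported at module scope and called with the context, is what `output` uses above to freeze non-ITensor values into TensorRT constants. A self-contained sketch of that promotion (shapes and names are illustrative):

    import tensorrt as trt
    import torch

    from torch_tensorrt.dynamo.conversion import ConversionContext
    from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor

    builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    ctx = ConversionContext(network)

    # Promotes a torch.Tensor (or scalar/ndarray) to a trt.ITensor backed by
    # a constant layer in ctx.net, so it can be marked as a network output.
    const = get_trt_tensor(ctx, torch.ones(2, 2), "frozen_const")
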
py/torch_tensorrt/dynamo/conversion/__init__.py

+1
@@ -1,3 +1,4 @@
+from ._ConversionContext import ConversionContext
 from ._TRTInterpreter import * # noqa: F403
 from .aten_ops_converters import * # noqa: F403
 from .conversion import * # noqa: F403

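Re-exporting the class from the package `__init__` keeps the import path short for converter authors:

    # Both now resolve to the same class, thanks to the re-export above:
    from torch_tensorrt.dynamo.conversion import ConversionContext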