Skip to content

Commit d177389

Browse files
committed
fix: Address compiling models with 0D inputs
- Remove dimension padding on inputs to TRT engines in runtime - Reorder memory format checks to exempt 0D FakeTensors - Add regression test
1 parent b3089bf commit d177389

File tree

3 files changed

+50
-12
lines changed

3 files changed

+50
-12
lines changed

core/runtime/execute_engine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
136136
TORCHTRT_CHECK(
137137
inputs[i].dtype() == expected_type,
138138
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
139-
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
139+
auto dims = core::util::toDims(inputs[i].sizes());
140140
auto shape = core::util::toVec(dims);
141141
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
142142
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);

py/torch_tensorrt/_Input.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ class _ShapeMode(Enum):
3232
shape: Optional[
3333
Tuple[int, ...] | Dict[str, Tuple[int, ...]]
3434
] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
35-
dtype: _enums.dtype = ( # type: ignore[name-defined]
35+
dtype: _enums.dtype = (
3636
_enums.dtype.unknown
3737
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
3838
_explicit_set_dtype: bool = False
39-
format: _enums.TensorFormat = ( # type: ignore[name-defined]
39+
format: _enums.TensorFormat = (
4040
_enums.TensorFormat.contiguous
4141
) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
4242

@@ -208,7 +208,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
208208
return False
209209

210210
@staticmethod
211-
def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
211+
def _parse_dtype(dtype: Any) -> _enums.dtype:
212212
if isinstance(dtype, torch.dtype):
213213
if dtype == torch.long:
214214
return _enums.dtype.long
@@ -236,7 +236,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
236236
)
237237

238238
@staticmethod
239-
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
239+
def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
240240
if dtype == _enums.dtype.long:
241241
return torch.long
242242
elif dtype == _enums.dtype.int32:
@@ -255,7 +255,7 @@ def is_trt_dtype(self) -> bool:
255255
return bool(self.dtype != _enums.dtype.long)
256256

257257
@staticmethod
258-
def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
258+
def _parse_format(format: Any) -> _enums.TensorFormat:
259259
if isinstance(format, torch.memory_format):
260260
if format == torch.contiguous_format:
261261
return _enums.TensorFormat.contiguous
@@ -337,18 +337,18 @@ def from_tensor(
337337
A Input object.
338338
"""
339339
if not (
340-
t.is_contiguous(memory_format=torch.contiguous_format)
340+
disable_memory_format_check
341+
or t.is_contiguous(memory_format=torch.contiguous_format)
341342
or t.is_contiguous(memory_format=torch.channels_last)
342-
or disable_memory_format_check
343343
):
344344
raise ValueError(
345345
"Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
346346
)
347347
frmt = (
348348
torch.contiguous_format
349349
if (
350-
t.is_contiguous(memory_format=torch.contiguous_format)
351-
or disable_memory_format_check
350+
disable_memory_format_check
351+
or t.is_contiguous(memory_format=torch.contiguous_format)
352352
)
353353
else torch.channels_last
354354
)

tests/py/dynamo/backend/test_specialized_models.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from utils import lower_graph_testing
2-
from torch.testing._internal.common_utils import run_tests, TestCase
31
import torch
42
import torch_tensorrt
3+
from torch.testing._internal.common_utils import TestCase, run_tests
4+
from utils import lower_graph_testing
55

66

77
class TestFakeTensors(TestCase):
@@ -118,5 +118,43 @@ def forward(self, x):
118118
torch._dynamo.reset()
119119

120120

121+
class Test0DTensors(TestCase):
122+
def test_0D_input(self):
123+
class Tensor0DInput(torch.nn.Module):
124+
def forward(self, x):
125+
return x * 7
126+
127+
inputs = [
128+
torch.tensor(
129+
3,
130+
)
131+
.cuda()
132+
.int(),
133+
]
134+
135+
fx_graph = torch.fx.symbolic_trace(Tensor0DInput())
136+
137+
# Validate that the results between Torch and Torch-TRT are similar
138+
optimized_model = torch_tensorrt.compile(
139+
fx_graph,
140+
"torch_compile",
141+
inputs,
142+
min_block_size=1,
143+
pass_through_build_failures=True,
144+
)
145+
optimized_model_results = optimized_model(*inputs).detach().cpu()
146+
torch_model_results = fx_graph(*inputs).detach().cpu()
147+
148+
max_diff = float(
149+
torch.max(torch.abs(optimized_model_results - torch_model_results))
150+
)
151+
self.assertAlmostEqual(
152+
max_diff,
153+
0,
154+
msg=f"0D-Tensor TRT outputs don't match with the original model.",
155+
)
156+
torch._dynamo.reset()
157+
158+
121159
if __name__ == "__main__":
122160
run_tests()

0 commit comments

Comments
 (0)