Skip to content

Commit e58f319

Browse files
authored
fix: Address runtimes with 0D inputs (#2188)
1 parent 56b8950 commit e58f319

File tree

3 files changed

+45
-7
lines changed

3 files changed

+45
-7
lines changed

core/runtime/execute_engine.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
136136
TORCHTRT_CHECK(
137137
inputs[i].dtype() == expected_type,
138138
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
139-
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
139+
auto dims = core::util::toDims(inputs[i].sizes());
140140
auto shape = core::util::toVec(dims);
141141
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
142142
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);

py/torch_tensorrt/_Input.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -339,18 +339,18 @@ def from_tensor(
339339
A Input object.
340340
"""
341341
if not (
342-
t.is_contiguous(memory_format=torch.contiguous_format)
342+
disable_memory_format_check
343+
or t.is_contiguous(memory_format=torch.contiguous_format)
343344
or t.is_contiguous(memory_format=torch.channels_last)
344-
or disable_memory_format_check
345345
):
346346
raise ValueError(
347347
"Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
348348
)
349349
frmt = (
350350
torch.contiguous_format
351351
if (
352-
t.is_contiguous(memory_format=torch.contiguous_format)
353-
or disable_memory_format_check
352+
disable_memory_format_check
353+
or t.is_contiguous(memory_format=torch.contiguous_format)
354354
)
355355
else torch.channels_last
356356
)

tests/py/dynamo/backend/test_specialized_models.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from utils import lower_graph_testing
2-
from torch.testing._internal.common_utils import run_tests, TestCase
31
import torch
42
import torch_tensorrt
3+
from torch.testing._internal.common_utils import TestCase, run_tests
4+
from utils import lower_graph_testing
55

66

77
class TestFakeTensors(TestCase):
@@ -118,5 +118,43 @@ def forward(self, x):
118118
torch._dynamo.reset()
119119

120120

121+
class Test0DTensors(TestCase):
122+
def test_0D_input(self):
123+
class Tensor0DInput(torch.nn.Module):
124+
def forward(self, x):
125+
return x * 7
126+
127+
inputs = [
128+
torch.tensor(
129+
3,
130+
)
131+
.cuda()
132+
.int(),
133+
]
134+
135+
fx_graph = torch.fx.symbolic_trace(Tensor0DInput())
136+
137+
# Validate that the results between Torch and Torch-TRT are similar
138+
optimized_model = torch_tensorrt.compile(
139+
fx_graph,
140+
"torch_compile",
141+
inputs,
142+
min_block_size=1,
143+
pass_through_build_failures=True,
144+
)
145+
optimized_model_results = optimized_model(*inputs).detach().cpu()
146+
torch_model_results = fx_graph(*inputs).detach().cpu()
147+
148+
max_diff = float(
149+
torch.max(torch.abs(optimized_model_results - torch_model_results))
150+
)
151+
self.assertAlmostEqual(
152+
max_diff,
153+
0,
154+
msg=f"0D-Tensor TRT outputs don't match with the original model.",
155+
)
156+
torch._dynamo.reset()
157+
158+
121159
if __name__ == "__main__":
122160
run_tests()

0 commit comments

Comments
 (0)