fix: Add support for fake tensors #1955

Merged · 2 commits · Jun 28, 2023
5 changes: 0 additions & 5 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -14,16 +14,13 @@
)
from torch_tensorrt.dynamo.backend.conversion import convert_module

-from torch._dynamo.backends.common import fake_tensor_unsupported
-
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler


logger = logging.getLogger(__name__)


@td.register_backend(name="torch_tensorrt")
-@fake_tensor_unsupported
def torch_tensorrt_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
@@ -35,7 +32,6 @@ def torch_tensorrt_backend(


@td.register_backend(name="aot_torch_tensorrt_aten")
-@fake_tensor_unsupported
def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
@@ -55,7 +51,6 @@ def aot_torch_tensorrt_aten_backend(
)


-@fake_tensor_unsupported
def _pretraced_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
114 changes: 114 additions & 0 deletions py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py
@@ -0,0 +1,114 @@
from utils import lower_graph_testing
from torch.testing._internal.common_utils import run_tests, TestCase
import torch
from torch_tensorrt.dynamo import compile


class TestFakeTensors(TestCase):
Contributor:

Are we enabling fake tensors in this test? I do not quite understand the purpose of this test file.

Collaborator (Author):

Based on the changes in backends.py, which remove all uses of the @fake_tensor_unsupported decorator, fake tensors will be enabled by default via Dynamo/AOT. The purpose of this test is to verify that utilities like create_constant do not instantiate Torch tensors when provided scalar inputs. For example, in the test test_lowering_mul_int below, the only op in the graph will be something like:

call_function[target=torch.ops.aten.mul.Tensor](args=(%x, 7)...)

Without the changes in this PR, the above fails at runtime: create_constant makes a torch.Tensor for the scalar 7, that tensor is fake (it holds no real values), and so when TRT goes to extract the value to build a constant tensor, the build fails.
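A minimal sketch of that idea is below, using a hypothetical helper rather than the repository's actual create_constant; the point is that a Python scalar is materialized as a concrete numpy array and never passes through torch.tensor(), so fake-tensor tracing cannot strip its value before TRT reads it.

```python
import numpy as np
import torch


def scalar_to_numpy_constant(value, dtype=None):
    """Hypothetical helper: return real (non-fake) data for a TRT constant."""
    if isinstance(value, (int, float, bool)):
        # Scalars never touch torch.tensor(), so they always carry real data,
        # even when the surrounding graph is traced with fake tensors.
        return np.array([value], dtype=dtype if dtype is not None else np.float32)
    if isinstance(value, torch.Tensor):
        # Concrete (non-fake) tensors can still be converted directly.
        return value.detach().cpu().numpy()
    raise TypeError(f"Unsupported constant value of type {type(value)}")
```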

def test_lowering_mul_int(self):
class MulInt(torch.nn.Module):
def forward(self, x):
return x * 7

# Operations expected to be included in the traced graph after decompositions
expected_ops = {
torch.ops.aten.mul.Tensor,
}

inputs = [
torch.rand(
3,
5,
7,
).cuda(),
]

fx_graph = torch.fx.symbolic_trace(MulInt())
_, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
min_block_size=1,
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
msg=f"MulInt TRT outputs don't match with the original model.",
)
torch._dynamo.reset()

def test_lowering_add_float(self):
class AddFloat(torch.nn.Module):
def forward(self, x):
return x + 84.0

# Operations expected to be included in the traced graph after decompositions
expected_ops = {
torch.ops.aten.add.Tensor,
}

inputs = [
torch.rand(
1,
5,
7,
9,
).cuda(),
]

fx_graph = torch.fx.symbolic_trace(AddFloat())
_, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
min_block_size=1,
)

self.assertEquals(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = compile(
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
msg=f"AddFloat TRT outputs don't match with the original model.",
)

torch._dynamo.reset()


if __name__ == "__main__":
run_tests()
11 changes: 9 additions & 2 deletions py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -16,7 +16,12 @@
from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
from .input_tensor_spec import InputTensorSpec
from torch_tensorrt.fx.observer import Observer
-from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
+from torch_tensorrt.fx.utils import (
+    get_dynamic_dims,
+    LowerPrecision,
+    unified_dtype_converter,
+    Frameworks,
+)

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -321,7 +326,9 @@ def placeholder(self, target, args, kwargs):
self.optimization_profiles[i].set_shape(target, *shape_range)

return self.network.add_input(
-            name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
+            name=target,
+            shape=tuple(shape),
+            dtype=unified_dtype_converter(dtype, Frameworks.TRT),
)

def call_module(self, target, args, kwargs):
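For reference, a rough sketch of the conversion that unified_dtype_converter performs is shown below; the real implementation lives in torch_tensorrt.fx.utils and covers more dtypes and error handling, and the NUMPY member of Frameworks is an assumption here (only TORCH and TRT appear in this diff).

```python
from enum import Enum, auto

import numpy as np
import tensorrt as trt
import torch


class Frameworks(Enum):
    NUMPY = auto()
    TORCH = auto()
    TRT = auto()


# One row per equivalent dtype across the three frameworks.
_EQUIVALENT_DTYPES = [
    {Frameworks.NUMPY: np.float32, Frameworks.TORCH: torch.float32,
     Frameworks.TRT: trt.float32},
    {Frameworks.NUMPY: np.float16, Frameworks.TORCH: torch.float16,
     Frameworks.TRT: trt.float16},
    {Frameworks.NUMPY: np.int32, Frameworks.TORCH: torch.int32,
     Frameworks.TRT: trt.int32},
    {Frameworks.NUMPY: np.bool_, Frameworks.TORCH: torch.bool,
     Frameworks.TRT: trt.bool},
]


def unified_dtype_converter(dtype, to: Frameworks):
    """Convert a numpy / torch / TensorRT dtype into the `to` framework's dtype."""
    for row in _EQUIVALENT_DTYPES:
        if any(dtype == candidate for candidate in row.values()):
            return row[to]
    raise TypeError(f"Unsupported dtype: {dtype!r}")
```

With a single table like this, torch_dtype_to_trt(d) becomes unified_dtype_converter(d, Frameworks.TRT) and torch_dtype_from_trt(d) becomes unified_dtype_converter(d, Frameworks.TORCH), which is how the call sites in the following diffs are rewritten.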
29 changes: 18 additions & 11 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -18,7 +18,7 @@
from torch.fx.immutable_collections import immutable_list
from torch.fx.node import Argument, Target

-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
+from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks

from .converter_utils import * # noqa: F403
from torch_tensorrt.fx.passes.lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
)

# cast value to TRTensor
-    dt = torch_dtype_from_trt(input_val.dtype)
+    dt = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
value = 0 if value == None else value
value_const = get_trt_tensor(
network, torch.tensor([value], dtype=dt), f"{name}_value"
@@ -1550,7 +1550,7 @@ def acc_ops_to_dtype(
input_t = get_trt_tensor(network, input_val, f"{name}_input_t")
if input_dtype:
if isinstance(input_dtype, torch.dtype):
-            input_dtype = torch_dtype_to_trt(input_dtype)
+            input_dtype = unified_dtype_converter(input_dtype, Frameworks.TRT)
input_t = type_cast(network, target, f"{name}_input", input_t, input_dtype)
return input_t

@@ -1811,7 +1811,7 @@ def acc_ops_logical_xor(
# f"isinf received input {input_t} that is not part "
# "of the TensorRT region!"
# )
-    # tdtype = torch_dtype_from_trt(input_t.dtype)
+    # tdtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)

# inf_t = torch.ones(tuple(input_t.shape))
# inf_t = inf_t * float("inf")
@@ -1849,7 +1849,7 @@ def acc_ops_any(

if input_t.dtype in (trt.float32, trt.float16, trt.int32):
comp_t = torch.zeros(tuple([*input_t.shape])).to(
-            torch_dtype_from_trt(input_t.dtype)
+            unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
)
comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
kwargs_new = {"input": input_t, "other": comp_t}
@@ -2738,7 +2738,7 @@ def acc_ops_masked_fill_tensor(
if type(value_t) is torch.Tensor:
value_t = value_t.cpu().numpy()
# cast to input type
-    input_dtype = torch_dtype_from_trt(input_t.dtype)
+    input_dtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
value_t = (torch.ones(shape) * value_t).to(input_dtype)
input_val = get_trt_tensor(network, input_t, f"{name}_input")
value_val = get_trt_tensor(network, value_t, f"{name}_input")
@@ -2872,7 +2872,11 @@ def add_clamp(network, input, val, op, name):
# clamping scalar
acc_ops_clamp_trt = get_trt_tensor(
network,
-        squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
+        squeeze_left(
+            torch.tensor(
+                [val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
+            )
+        ),
f"{name}_clamp_{val}",
)
else:
@@ -2881,7 +2885,8 @@ def add_clamp(network, input, val, op, name):
(
val
* torch.ones(
-                    acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
+                    acc_ops_clamp_shape,
+                    dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
)
)
.cpu()
@@ -3527,7 +3532,9 @@ def acc_ops_cumsum(
iterator = loop.add_iterator(input_val, dim, False)
data = iterator.get_output(0)
new_dims = tuple(data.shape)
-    zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
+    zero_tensor = torch.zeros(
+        new_dims, dtype=unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
+    )
zero_tensor = network.add_constant(
zero_tensor.shape, to_numpy(zero_tensor)
).get_output(0)
@@ -3670,7 +3677,7 @@ def acc_ops_new_ones(
dtype_val = kwargs.get("dtype")
if dtype_val is None:
dtype_val = input_val.dtype
-    dtype_val = torch_dtype_from_trt(dtype_val)
+    dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)

device_val = kwargs.get("device")
assert (
@@ -3694,7 +3701,7 @@ def acc_ops_new_empty(
dtype_val = kwargs.get("dtype")
if dtype_val is None:
dtype_val = input_val.dtype
-    dtype_val = torch_dtype_from_trt(dtype_val)
+    dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)

device_val = kwargs.get("device")
assert (
2 changes: 0 additions & 2 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -18,8 +18,6 @@
from torch.fx.immutable_collections import immutable_list
from torch.fx.node import Argument, Target

-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
-
from .converter_utils import * # noqa: F403
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
from torch_tensorrt.fx.converters.impl import activation