fix: Add support for fake tensors #1955
Conversation
Force-pushed from 38e80b1 to e1555bc
Force-pushed from 510552b to 0861418
Force-pushed from 0861418 to c6e677c

- Refactor `to_numpy` function to handle non-tensor inputs, avoiding fake tensor issue during compilation of constants
- Add regression test case to elicit behavior

Force-pushed from ba0468f to 9305811
TRT = "trt" | ||
|
||
|
||
DataTypeEquivalence: Dict[ |
Unified data type translator dictionary which takes a TRTDataType
and translates it to a Numpy, Torch, or TRT data type of the equivalent type and precision.
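For illustration, here is a minimal sketch of how such a translator dictionary might be structured, based only on the `TRT = "trt"` enum member and the `DataTypeEquivalence: Dict[` annotation visible in the diff; the `Frameworks` enum name and the exact set of entries are assumptions:

```python
from enum import Enum
from typing import Any, Dict

import numpy as np
import tensorrt as trt
import torch


class Frameworks(Enum):
    NUMPY = "numpy"
    TORCH = "torch"
    TRT = "trt"


# Keyed by TRT dtype; each inner dict maps a target framework to the
# equivalent dtype of the same precision. Entries here are illustrative.
DataTypeEquivalence: Dict[trt.DataType, Dict[Frameworks, Any]] = {
    trt.int32: {
        Frameworks.NUMPY: np.int32,
        Frameworks.TORCH: torch.int32,
        Frameworks.TRT: trt.int32,
    },
    trt.float16: {
        Frameworks.NUMPY: np.float16,
        Frameworks.TORCH: torch.float16,
        Frameworks.TRT: trt.float16,
    },
    trt.float32: {
        Frameworks.NUMPY: np.float32,
        Frameworks.TORCH: torch.float32,
        Frameworks.TRT: trt.float32,
    },
}

# Example: translate a TRT dtype to its torch equivalent.
torch_dtype = DataTypeEquivalence[trt.float16][Frameworks.TORCH]
```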
Force-pushed from 9305811 to d84ccb9
```python
from torch_tensorrt.dynamo import compile

class TestFakeTensors(TestCase):
```
Are we enabling fake tensors in this test? I do not quite understand the purpose of this test file.
Based on the changes in `backends.py`, which remove all calls to `@fake_tensor_unsupported`, fake tensors will be enabled by default via Dynamo/AOT. The purpose of this test is to verify that utilities like `create_constant` do not instantiate Torch tensors when provided scalar inputs. For example, in the test `test_lowering_mul_int` below, the only op in the graph will be something like:

```
call_function[target=torch.ops.aten.mul.Tensor](args=(%x, 7), ...)
```

Without the changes in this PR, the above will fail at runtime because `create_constant` will make a `torch.Tensor` for the scalar 7, and this tensor will be fake (hold no values), so when TRT goes to extract the value to make a constant tensor, the script fails.
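For context, a minimal sketch of what such a regression test might look like; the module body, input shapes, and the exact `compile` call details are assumptions, since only the imports and class name appear in this excerpt:

```python
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo import compile


class TestFakeTensors(TestCase):
    def test_lowering_mul_int(self):
        # Multiplying by a Python int yields a graph whose only op is
        # call_function[target=torch.ops.aten.mul.Tensor](args=(%x, 7), ...),
        # forcing create_constant to materialize the scalar 7 for TRT.
        class MulInt(torch.nn.Module):
            def forward(self, x):
                return x * 7

        inputs = [torch.rand(3, 4).cuda()]
        model = MulInt().cuda()
        trt_model = compile(model, inputs)
        torch.testing.assert_close(trt_model(*inputs), model(*inputs))


if __name__ == "__main__":
    run_tests()
```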
```python
    output = np.array([value], dtype=np.int32)
elif isinstance(value, float):
    output = np.array([value], dtype=np.float32)
```
Add support for fp16?
For this particular instance, the intent was just to take Python builtins and translate them to numpy arrays. If the caller specifies an FP16 dtype in the function schema, that cast will be applied later, on line 196.
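A hedged sketch of that flow, with `builtin_to_numpy` as a hypothetical name and the deferred cast folded into one helper for illustration:

```python
import numpy as np


def builtin_to_numpy(value, dtype=None):
    # Python builtins first get default 32-bit numpy dtypes...
    if isinstance(value, int):
        output = np.array([value], dtype=np.int32)
    elif isinstance(value, float):
        output = np.array([value], dtype=np.float32)
    else:
        raise TypeError(f"Unsupported scalar type: {type(value)}")
    # ...and any dtype from the function schema (e.g. np.float16) is
    # applied as a cast afterwards, mirroring the later cast the
    # comment above refers to.
    if dtype is not None:
        output = output.astype(dtype)
    return output
```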
Description

- Refactor `to_numpy` function to handle non-tensor inputs, avoiding fake tensor issue during compilation of constants (see the sketch after this list)
- Use `numpy` data types and shapes when creating constants, as a workaround for `torch.Tensor` instantiations being fake
- Add a unified data type translation dictionary which takes any of the `{TRT, Numpy, Torch}` data types and translates them to any other of those

Fixes #1951
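As referenced in the first bullet, a minimal sketch of the refactored `to_numpy` behavior; the exact signature and error handling are assumed:

```python
import numpy as np
import torch


def to_numpy(value):
    # Real tensors are converted via detach(); non-tensor inputs (ints,
    # floats, sequences) are wrapped directly with numpy, so no
    # torch.Tensor is instantiated for constants under fake-tensor tracing.
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.array(value)
```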
Type of change
Checklist: