fix: Add support for fake tensors #1955

Merged
merged 2 commits into from
Jun 28, 2023

Conversation

gs-olive
Collaborator

@gs-olive gs-olive commented May 26, 2023

Description

  • Refactor to_numpy function to handle non-tensor inputs, avoiding fake tensor issue during compilation of constants
  • Refactor multiple functions to allow for numpy data types and shapes, as a workaround for torch.Tensor instantiations being fake
  • Develop new "unified" data type translation which takes any of {TRT, Numpy, Torch} data types and translates them to any other of those
  • Add regression test case to elicit behavior

Fixes #1951

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • [ x ] My code follows the style guidelines of this project (You can use the linters)
  • [ x ] I have performed a self-review of my own code
  • [ x ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] I have added tests to verify my fix or my feature
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have added the relevant labels to my PR so that relevant reviewers are notified

@gs-olive gs-olive added the component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths label May 26, 2023
@gs-olive gs-olive self-assigned this May 26, 2023
@github-actions github-actions bot requested a review from yinghai May 26, 2023 16:37
@gs-olive gs-olive removed the request for review from yinghai May 26, 2023 16:37
@gs-olive gs-olive force-pushed the remove_fake_tensor_unsupported branch from 38e80b1 to e1555bc Compare May 26, 2023 17:05
@gs-olive gs-olive added the WIP Work is in progress, pull request should not be merged yet label May 26, 2023
@github-actions github-actions bot requested a review from yinghai May 26, 2023 18:20
@gs-olive gs-olive force-pushed the remove_fake_tensor_unsupported branch 2 times, most recently from 510552b to 0861418 Compare May 26, 2023 20:35
@gs-olive gs-olive removed the request for review from yinghai May 26, 2023 20:36
@gs-olive gs-olive added the Story: Dynamo Compile Improvements Issues relating to improvement of the Dynamo compile path label May 26, 2023
@github-actions github-actions bot requested a review from yinghai May 26, 2023 20:40
@gs-olive gs-olive removed the request for review from yinghai May 26, 2023 20:46
@gs-olive gs-olive force-pushed the remove_fake_tensor_unsupported branch from 0861418 to c6e677c Compare May 26, 2023 21:39
- Refactor `to_numpy` function to handle non-tensor inputs, avoiding
fake tensor issue during compilation of constants
- Add regression test case to elicit behavior
@gs-olive gs-olive force-pushed the remove_fake_tensor_unsupported branch 2 times, most recently from ba0468f to 9305811 Compare June 6, 2023 02:50
TRT = "trt"


DataTypeEquivalence: Dict[
Collaborator Author

@gs-olive gs-olive Jun 7, 2023

Unified data type translator dictionary which takes a TRTDataType and translates it to a Numpy, Torch, or TRT data type of the equivalent type and precision.
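
A minimal sketch of how such a translation table might be organised, assuming one row per precision. The names `Framework`, `_DTYPE_ROWS`, and `translate_dtype` are hypothetical and are not taken from this PR, and dotted strings stand in for the real dtype objects so that the snippet runs without Torch or TensorRT installed.

```python
from enum import Enum
from typing import Dict, List


class Framework(Enum):
    """Hypothetical tag for the library a data type belongs to."""

    NUMPY = "numpy"
    TORCH = "torch"
    TRT = "trt"


# One row per precision. Dotted strings stand in for the real dtype objects
# (for example the string "torch.float16" in place of torch.float16).
_DTYPE_ROWS: List[Dict[Framework, str]] = [
    {
        Framework.NUMPY: f"numpy.{name}",
        Framework.TORCH: f"torch.{name}",
        Framework.TRT: f"tensorrt.{name}",
    }
    for name in ("int8", "int32", "float16", "float32")
]
# Booleans are spelled differently in NumPy, so they get an explicit row.
_DTYPE_ROWS.append(
    {
        Framework.NUMPY: "numpy.bool_",
        Framework.TORCH: "torch.bool",
        Framework.TRT: "tensorrt.bool",
    }
)


def translate_dtype(dtype: str, to: Framework) -> str:
    """Return the equivalent of `dtype` in the framework `to`."""
    for row in _DTYPE_ROWS:
        if dtype in row.values():
            return row[to]
    raise TypeError(f"No known equivalent for data type {dtype!r}")


print(translate_dtype("tensorrt.float16", to=Framework.NUMPY))
print(translate_dtype("numpy.int32", to=Framework.TORCH))
```

Storing each precision once and searching across the row keeps every direction of translation consistent, at the cost of a linear scan over a handful of rows.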

@gs-olive gs-olive force-pushed the remove_fake_tensor_unsupported branch from 9305811 to d84ccb9 Compare June 7, 2023 05:16
@gs-olive gs-olive removed the WIP Work is in progress, pull request should not be merged yet label Jun 7, 2023
@gs-olive gs-olive requested a review from frank-wei June 7, 2023 05:16
@gs-olive gs-olive marked this pull request as ready for review June 7, 2023 05:16
@gs-olive gs-olive added the WIP Work is in progress, pull request should not be merged yet label Jun 7, 2023
@github-actions github-actions bot requested a review from wushirong June 7, 2023 18:24
from torch_tensorrt.dynamo import compile


class TestFakeTensors(TestCase):
Contributor

Are we enabling fake tensors in this test? I do not quite understand the purpose of this test file.

Collaborator Author

Based on the changes in backends.py, which remove all calls to @fake_tensor_unsupported, fake tensors will be enabled by default via Dynamo/AOT. The purpose of this test is to verify that utilities like create_constant do not instantiate Torch tensors when provided scalar inputs. For example, in the test test_lowering_mul_int below, the only op in the graph will be something like:

call_function[target=torch.ops.aten.mul.Tensor](args=(%x, 7)...)

Without the changes in this PR, the above fails at runtime: create_constant makes a torch.Tensor for the scalar 7, that tensor is fake (it holds no values), and the script fails when TRT tries to extract the value to build a constant tensor.

    output = np.array([value], dtype=np.int32)

elif isinstance(value, float):
    output = np.array([value], dtype=np.float32)
Contributor

Add support for FP16?

Collaborator Author

For this particular instance, the intent was just to take Python builtins and translate them to numpy arrays. If the caller specifies an FP16 dtype in the function schema, that cast will be applied later, on line 196.
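
The two-step behavior described in this reply (map Python builtins to fixed NumPy types first, apply any caller-requested precision afterwards) can be sketched roughly as follows. This is an illustration only, not the code in this PR: the function name `scalar_to_numpy`, its `dtype` parameter, and the `bool` branch are assumptions.

```python
from typing import Optional, Union

import numpy as np


def scalar_to_numpy(
    value: Union[bool, int, float, np.ndarray],
    dtype: Optional[type] = None,
) -> np.ndarray:
    """Convert a Python builtin to a one-element NumPy array without using Torch."""
    if isinstance(value, np.ndarray):
        output = value
    # Assumption: bool is handled separately, and before int, because bool is
    # a subclass of int in Python.
    elif isinstance(value, bool):
        output = np.array([value], dtype=np.bool_)
    elif isinstance(value, int):
        output = np.array([value], dtype=np.int32)
    elif isinstance(value, float):
        output = np.array([value], dtype=np.float32)
    else:
        raise TypeError(f"Unsupported constant type: {type(value).__name__}")

    # A caller-specified precision (for example FP16) is applied as a separate,
    # later step rather than inside the builtin-type branches above.
    return output if dtype is None else output.astype(dtype)


print(scalar_to_numpy(7).dtype)
print(scalar_to_numpy(7, dtype=np.float16).dtype)
```

Because no torch.Tensor is created at any point, the scalar's value remains available even when tensors instantiated during compilation would be fake.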

@gs-olive gs-olive requested a review from frank-wei June 9, 2023 16:42
@gs-olive gs-olive removed the WIP Work is in progress, pull request should not be merged yet label Jun 27, 2023
@gs-olive gs-olive merged commit 6dcd1fc into pytorch:main Jun 28, 2023
@gs-olive gs-olive deleted the remove_fake_tensor_unsupported branch June 28, 2023 21:01
Labels
  • cla signed
  • component: api [Python] (Issues re: Python API)
  • component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)
  • component: fx
  • fx
  • Story: Dynamo Compile Improvements (Issues relating to improvement of the Dynamo compile path)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

✨[Feature] Fully support FakeTensors in Dynamo compile
3 participants