Skip to content

fix: Fix static arange export #3194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions py/torch_tensorrt/dynamo/_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ def lift(
)
assert fake_mode is not None

# This map stores the names of outputs (old to new)
# This is necessary to track because the output names can be changed when
# we convert graph constants to placeholder inputs below.
output_names = {}
for output_spec in graph_signature.output_specs:
output_names[output_spec.arg.name] = output_spec.arg.name

# Locate the user input to insert new placeholders before them
first_user_input = None
for node in gm.graph.nodes:
Expand Down Expand Up @@ -139,9 +146,8 @@ def lift(
# Replace get_attr nodes with placeholder nodes and copy metadata.
with gm.graph.inserting_before(first_user_input):
# Ensure name doesn't contain period as it is used for submodules
const_placeholder_node = gm.graph.placeholder(
node.target.replace(".", "_")
)
const_placeholder_name = node.target.replace(".", "_")
const_placeholder_node = gm.graph.placeholder(const_placeholder_name)
# Copy the node meta into this new placeholder node
const_placeholder_node.meta = node.meta

Expand All @@ -157,6 +163,12 @@ def lift(
node.replace_all_uses_with(const_placeholder_node)
gm.graph.erase_node(node)

# Verify if the const_placeholder being added is one of the output nodes
# This happens if there is just a single static arange op in the graph
# https://github.com/pytorch/TensorRT/issues/3189
if const_placeholder_name in output_names:
output_names[const_placeholder_name] = const_placeholder_node.name

# Add these parameters/buffers/constants to the existing graph signature
# before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
Expand All @@ -174,6 +186,11 @@ def lift(
)
non_user_input_idx += 1

# Update output_specs with modified names. This only gets updated if the graph getattr nodes (weights)
# are also the outputs of the graph
for output_spec in graph_signature.output_specs:
output_spec.arg.name = output_names[output_spec.arg.name]

gm.graph.eliminate_dead_code()
gm.graph.lint()

Expand Down
58 changes: 58 additions & 0 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,64 @@ def forward(self, x):
)


@pytest.mark.unit
def test_arange_export(ir):
    """
    Round-trip (compile -> save -> load) a graph whose sole op is a static arange.

    With a static input shape the arange result is a compile-time constant,
    which the exporter lifts into a placeholder input. Verify that both the
    compiled TRT module and its deserialized counterpart match eager PyTorch.
    """

    class MyModule(torch.nn.Module):
        def forward(self, x):
            # arange over the last input dimension; constant-folded at export
            # time because the input shape is fixed.
            return torch.arange(
                1, x.shape[-1] + 1, dtype=torch.float32, device=x.device
            )

    model = MyModule().eval().cuda()
    sample = torch.randn((1, 1, 128, 128)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                sample.shape, dtype=torch.float, format=torch.contiguous_format
            )
        ],
        "ir": ir,
        "min_block_size": 1,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
    }

    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)

    torchtrt.save(trt_module, trt_ep_path, inputs=[sample])
    deser_trt_module = torchtrt.load(trt_ep_path).module()

    outputs_pyt = model(sample)
    outputs_trt = trt_module(sample)

    # Compiled module must agree with eager execution.
    for i in range(len(outputs_pyt)):
        cos_sim = cosine_similarity(outputs_pyt[i], outputs_trt[i])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"test_arange_export TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )

    # The serialized-then-deserialized module must agree as well.
    outputs_trt_deser = deser_trt_module(sample)
    for i in range(len(outputs_pyt)):
        cos_sim = cosine_similarity(outputs_pyt[i], outputs_trt_deser[i])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"test_arange_export deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )


@pytest.mark.unit
def test_save_load_ts(ir):
"""
Expand Down
Loading