Skip to content

Commit 28a2e35

Browse files
committed
fix: Fix static arange export
1 parent 43eb560 commit 28a2e35

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

py/torch_tensorrt/dynamo/_exporter.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,13 @@ def lift(
9898
)
9999
assert fake_mode is not None
100100

101+
# This map stores the names of outputs (old to new)
102+
# This is necessary to track because the output names can be changed when
103+
# we convert graph constants to placeholder inputs below.
104+
output_names = {}
105+
for output_spec in graph_signature.output_specs:
106+
output_names[output_spec.arg.name] = output_spec.arg.name
107+
101108
# Locate the user input to insert new placeholders before them
102109
first_user_input = None
103110
for node in gm.graph.nodes:
@@ -139,9 +146,8 @@ def lift(
139146
# Replace get_attr nodes with placeholder nodes and copy metadata.
140147
with gm.graph.inserting_before(first_user_input):
141148
# Ensure name doesn't contain period as it is used for submodules
142-
const_placeholder_node = gm.graph.placeholder(
143-
node.target.replace(".", "_")
144-
)
149+
const_placeholder_name = node.target.replace(".", "_")
150+
const_placeholder_node = gm.graph.placeholder(const_placeholder_name)
145151
# Copy the node meta into this new placeholder node
146152
const_placeholder_node.meta = node.meta
147153

@@ -157,6 +163,12 @@ def lift(
157163
node.replace_all_uses_with(const_placeholder_node)
158164
gm.graph.erase_node(node)
159165

166+
# Check whether the const_placeholder being added is also one of the graph outputs.
167+
# This happens if there is just a single static arange op in the graph
168+
# https://github.com/pytorch/TensorRT/issues/3189
169+
if const_placeholder_name in output_names:
170+
output_names[const_placeholder_name] = const_placeholder_node.name
171+
160172
# Add these parameters/buffers/constants to the existing graph signature
161173
# before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
162174
input_spec_arg = TensorArgument(name=const_placeholder_node.name)
@@ -174,6 +186,11 @@ def lift(
174186
)
175187
non_user_input_idx += 1
176188

189+
# Update output_specs with modified names. This only gets updated if the graph getattr nodes (weights)
190+
# are also the outputs of the graph
191+
for output_spec in graph_signature.output_specs:
192+
output_spec.arg.name = output_names[output_spec.arg.name]
193+
177194
gm.graph.eliminate_dead_code()
178195
gm.graph.lint()
179196

tests/py/dynamo/models/test_export_serde.py

+58
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,64 @@ def forward(self, x):
381381
)
382382

383383

384+
@pytest.mark.unit
def test_arange_export(ir):
    """
    Test export save and load functionality on a static arange graph.

    The module's only output is a static arange, i.e. a constant that the
    exporter registers as an input (placeholder) to the graph. The compiled
    module is saved, reloaded, and both the in-memory and the deserialized
    modules are compared against the eager PyTorch output.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            # Only the (static) last dimension of x is used, so the result
            # is a constant tensor after tracing.
            x_embed = torch.arange(
                1, x.shape[-1] + 1, dtype=torch.float32, device=x.device
            )
            return x_embed

    model = MyModule().eval().cuda()
    # Named `input_tensor` rather than `input` to avoid shadowing the builtin.
    input_tensor = torch.randn((1, 1, 128, 128)).to("cuda")

    compile_spec = {
        "inputs": [
            torchtrt.Input(
                input_tensor.shape,
                dtype=torch.float,
                format=torch.contiguous_format,
            )
        ],
        "ir": ir,
        "min_block_size": 1,
        "cache_built_engines": False,
        "reuse_cached_engines": False,
    }

    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)

    torchtrt.save(trt_module, trt_ep_path, inputs=[input_tensor])

    deser_trt_module = torchtrt.load(trt_ep_path).module()
    outputs_pyt = model(input_tensor)

    # Check the in-memory module first, then the deserialized one. Each
    # module is only invoked right before its own comparison, so the
    # execution order matches running the two checks back to back.
    modules_under_test = (
        ("TRT", trt_module),
        ("deserialized TRT", deser_trt_module),
    )
    for label, module in modules_under_test:
        outputs = module(input_tensor)
        for idx, expected in enumerate(outputs_pyt):
            cos_sim = cosine_similarity(expected, outputs[idx])
            assertions.assertTrue(
                cos_sim > COSINE_THRESHOLD,
                msg=f"test_arange_export {label} outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
            )
440+
441+
384442
@pytest.mark.unit
385443
def test_save_load_ts(ir):
386444
"""

0 commit comments

Comments
 (0)