@@ -98,6 +98,13 @@ def lift(
98
98
)
99
99
assert fake_mode is not None
100
100
101
+ # This map stores the names of outputs (old to new)
102
+ # This is necessary to track because the output names can be changed when
103
+ # we convert graph constants to placeholder inputs below.
104
+ output_names = {}
105
+ for output_spec in graph_signature .output_specs :
106
+ output_names [output_spec .arg .name ] = output_spec .arg .name
107
+
101
108
# Locate the user input to insert new placeholders before them
102
109
first_user_input = None
103
110
for node in gm .graph .nodes :
@@ -139,9 +146,8 @@ def lift(
139
146
# Replace get_attr nodes with placeholder nodes and copy metadata.
140
147
with gm .graph .inserting_before (first_user_input ):
141
148
# Ensure name doesn't contain period as it is used for submodules
142
- const_placeholder_node = gm .graph .placeholder (
143
- node .target .replace ("." , "_" )
144
- )
149
+ const_placeholder_name = node .target .replace ("." , "_" )
150
+ const_placeholder_node = gm .graph .placeholder (const_placeholder_name )
145
151
# Copy the node meta into this new placeholder node
146
152
const_placeholder_node .meta = node .meta
147
153
@@ -157,6 +163,12 @@ def lift(
157
163
node .replace_all_uses_with (const_placeholder_node )
158
164
gm .graph .erase_node (node )
159
165
166
+ # Verify if the const_placeholder being added is one of the output nodes
167
+ # This happens if there is just a single static arange op in the graph
168
+ # https://github.com/pytorch/TensorRT/issues/3189
169
+ if const_placeholder_name in output_names :
170
+ output_names [const_placeholder_name ] = const_placeholder_node .name
171
+
160
172
# Add these parameters/buffers/constants to the existing graph signature
161
173
# before user inputs. These specs are looked up in the state_dict during ExportedProgram creation.
162
174
input_spec_arg = TensorArgument (name = const_placeholder_node .name )
@@ -174,6 +186,11 @@ def lift(
174
186
)
175
187
non_user_input_idx += 1
176
188
189
+ # Update output_specs with modified names. This only gets updated if the graph getattr nodes (weights)
190
+ # are also the outputs of the graph
191
+ for output_spec in graph_signature .output_specs :
192
+ output_spec .arg .name = output_names [output_spec .arg .name ]
193
+
177
194
gm .graph .eliminate_dead_code ()
178
195
gm .graph .lint ()
179
196
0 commit comments