+import torch
+from torch.fx.node import _get_qualified_name
+from typing import Optional, Sequence, Union
+
+
+def _extract_downstream_get_nodes(
+    module_node: torch.fx.Node, output_indices: Sequence[int]
+) -> Sequence[torch.fx.Node]:
+    """Extracts downstream users of a node which get the item at a particular index
+
+    Certain module-type nodes have multiple outputs (tuple of outputs). This function
+    returns downstream nodes which call the _operator.getitem function, which extracts
+    the element at a particular index in the tuple.
+
+    Args:
+        module_node: FX module-type node to analyze
+        output_indices: Indices in the module node output to search for
+    Returns:
+        List of nodes which get the item at the specified index in the module node output
+    """
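+    # For illustration only (hypothetical FX graph): tuple outputs of a
+    # submodule are typically consumed as
+    #     %trt_mod = call_module[target=_run_on_acc_0](args = (%x,))
+    #     %out_0 = call_function[target=_operator.getitem](args = (%trt_mod, 0))
+    # Any getitem user whose index falls in output_indices is collected below.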
+    get_nodes = []
+
+    # Iterate over all downstream users of the node object
+    for user in module_node.users:
+        # If the user is a "get" node accessing one of the specified indices, store it
+        # (guard on op type, since non-call_function targets are not callables)
+        if (
+            user.op == "call_function"
+            and _get_qualified_name(user.target) == "_operator.getitem"
+            and user.args[1] in output_indices
+        ):
+            get_nodes.append(user)
+
+    return get_nodes
+
+
+def _repair_64bit_input(
+    gm: torch.fx.GraphModule,
+    position: int,
+    submodule_name: str,
+    submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
+    dtype: torch.dtype,
+):
+    """Fixes a single Long/Double input to a TRT-accelerated subgraph
+
+    In-place modifies the provided graph
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then, if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        gm: FX GraphModule enclosing the TRT subgraph
+        position: Index in the submodule inputs at which the long or double input is found
+        submodule_name: Name of TRT-accelerated subgraph module in FX graph
+        submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        dtype: Data type of tensor at position in submodule (double/long)
+    """
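+    # Illustrative sketch of the rewrite performed below (node names hypothetical):
+    #     before:  %x_64 -> %trt_module -> ...
+    #     after:   %x_64 -> aten._to_copy(dtype=32-bit) -> %trt_module
+    #              -> aten._to_copy(dtype=64-bit) -> ...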
+    assert dtype in (
+        torch.int64,
+        torch.float64,
+    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
+
+    # Determine target data type in 32 and 64 bit forms
+    dtype_64bit = dtype
+    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
+
+    # Find the node representing the submodule in the graph
+    module_node = None
+
+    # Iterate over all nodes in the graph, seeking target module name match
+    for n in gm.graph.nodes:
+        if n.op == "call_module" and str(n.target) == submodule_name:
+            module_node = n
+            break
+
+    if module_node is None:
+        raise AssertionError(
+            f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
+        )
+
+    # Extract the 64-bit node of the input
+    node_64bit = module_node.all_input_nodes[position]
+
+    # Prior to the module, insert a cast to the 32-bit equivalent node
+    with gm.graph.inserting_before(module_node):
+        node_32bit = gm.graph.call_function(
+            torch.ops.aten._to_copy.default,
+            args=(node_64bit,),
+            kwargs={"dtype": dtype_32bit},
+        )
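+    # (aten._to_copy is the aten-level cast op that Tensor.to(dtype=...)
+    # typically decomposes to, so later passes treat it like any other cast)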
+
+    # Replace 64-bit input to TRT module with new 32-bit cast node
+    module_node.replace_input_with(node_64bit, node_32bit)
+
+    output_positions_64bit = set()
+    outputs_list = (
+        [submodule_outputs]
+        if isinstance(submodule_outputs, torch.Tensor)
+        else submodule_outputs
+    )
+
+    # Determine if any outputs of the module are 64-bit type and store their indices
+    if submodule_outputs is not None:
+        for output_position, output in enumerate(outputs_list):
+            if output.dtype == dtype_64bit:
+                output_positions_64bit.add(output_position)
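+    # (e.g. outputs with dtypes (float32, int64, float32) and
+    # dtype_64bit=torch.int64 yield output_positions_64bit == {1})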
+
+    # Only enter this code block if there exists a 64-bit output
+    # This implies a cast is needed, since TRT cannot output 64-bit tensors
+    if output_positions_64bit:
+        # Determine whether the outputs of the module are collection-type (tuple/list)
+        is_collection_output = isinstance(submodule_outputs, (tuple, list))
+
+        if not is_collection_output:
+            # If the output is a single tensor, insert a cast back to the 64-bit type
+            with gm.graph.inserting_after(module_node):
+                cast_node_64bit = gm.graph.call_function(
+                    torch.ops.aten._to_copy.default,
+                    args=(module_node,),
+                    kwargs={"dtype": dtype_64bit},
+                )
+
+            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
+            module_node.replace_all_uses_with(
+                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
+            )
+
+        else:
+            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
+            get_nodes = _extract_downstream_get_nodes(
+                module_node, output_positions_64bit
+            )
+
+            # For each downstream user, append a cast node back to the 64-bit precision
+            for get_node in get_nodes:
+                with gm.graph.inserting_after(get_node):
+                    cast_node_64bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(get_node,),
+                        kwargs={"dtype": dtype_64bit},
+                    )
+
+                get_node.replace_all_uses_with(
+                    cast_node_64bit,
+                    delete_user_cb=lambda user: (user != cast_node_64bit),
+                )
+
| 146 | + |
| 147 | + # Clean up graph and ensure invariants are preserved |
| 148 | + gm.graph.eliminate_dead_code() |
| 149 | + gm.graph.lint() |
| 150 | + gm.recompile() |
| 151 | + |
| 152 | + |
+def repair_long_or_double_inputs(
+    parent_graph: torch.fx.GraphModule,
+    submodule: torch.fx.GraphModule,
+    submodule_inputs: Sequence[torch.Tensor],
+    submodule_name: Optional[str] = None,
+) -> Sequence[torch.Tensor]:
+    """Fixes all Long/Double type inputs to a TRT-accelerated subgraph
+
+    In-place modifies the provided graph
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then, if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        parent_graph: FX GraphModule enclosing the TRT subgraph
+        submodule: Child submodule to repair inputs on
+        submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        submodule_name: Optionally specify the name of the submodule target in the parent graph
+    Returns:
+        New submodule inputs, updated accordingly with long/double truncation
+    """
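+    # For example, an input of dtype torch.int64 is truncated to torch.int32
+    # (torch.float64 -> torch.float32), mirroring the casts which
+    # _repair_64bit_input inserts into the parent graph.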
+    num_submodule_inputs = len(submodule_inputs)
+    repaired_outputs_once = False
+
+    # Normalize to a tuple so the positional replacement below is well-defined
+    # for any Sequence input
+    submodule_inputs = tuple(submodule_inputs)
+
+    # For each input to the TRT subgraph, check if its type is long/double
+    for position in range(num_submodule_inputs):
+        param = submodule_inputs[position]
+
+        # If the data type of the input is long/double, insert necessary
+        # casts to replace the operation
+        if param.dtype in (torch.int64, torch.float64):
+            # Ensure outputs are only repaired once per submodule to avoid
+            # unnecessary ops showing up in the graph
+            if not repaired_outputs_once:
+                submodule_outputs = submodule(*submodule_inputs)
+
+            _repair_64bit_input(
+                parent_graph,
+                position,
+                submodule_name if submodule_name is not None else submodule._get_name(),
+                None if repaired_outputs_once else submodule_outputs,
+                param.dtype,
+            )
+
+            repaired_outputs_once = True
+
+            # Repair submodule inputs in accordance with inserted casts
+            dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
+            submodule_inputs = (
+                submodule_inputs[:position]
+                + (param.to(dtype_32bit),)
+                + submodule_inputs[position + 1 :]
+            )
+
+    return submodule_inputs
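+
+
+# Example usage (illustrative sketch; the partition name "_run_on_acc_0" and
+# the tensors below are hypothetical):
+#
+#     inputs = (torch.zeros(2, 3, dtype=torch.int64),)
+#     new_inputs = repair_long_or_double_inputs(
+#         parent, getattr(parent, "_run_on_acc_0"), inputs, "_run_on_acc_0"
+#     )
+#     # new_inputs now holds torch.int32 tensors, and the parent graph casts
+#     # the submodule's inputs/outputs between 32- and 64-bit as needed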