
Commit eaf2bb4

wushirong authored and narendasan committed
Changes done internally at Facebook
483102cd4151f02c2d3632e6b6df7a5e59c0d6f3 Wei Wei <[email protected]> [fx2trt] move acc op `torch.ops._caffe2.RoIAlign` to fb only
8ce94a01caa090d56adb4708452b52890160ba69 Wei Wei <[email protected]> [aten2trt] reshape support
422326213bad177019e92c95dbc61af7a427bebc Shirong Wu <[email protected]> nan_to_num aten converter
f729c8a7f1268f329d15e3cf05f1fb9232fab2d9 Huamin Li <[email protected]> Record TRT/AIT lower context into Scuba gpu_lowering_diagnostics
2df64af6bcf102a0ce40f1c5ab8472370d012904 Wei Wei <[email protected]> [aten2ait][fx2ait] sin,cos,sqrt,clone support
9fa6469ccb9d00320d78684d748fe1a7e5c3cf60 Janet Yang <[email protected]> Split nodes w/ float64 inputs from lowering
d2ea242f721156df9e075927ea7956db772d4107 Fei Kou <[email protected]> Handle Ellipsis in dper passes
d053b097a0d1c158cde29792a35c4ec4174d9417 Jason Ansel <[email protected]> Fix tests broken by D42953629
e18c6c76b1678a95c35583dabb41666b33c3df63 Zhijing Li (Accelerator Enablement) <[email protected]> Add dper test for push_down_split pass
5008c6d200f2a9ca035547204b47eb5e1704ce88 Zhijing Li (Accelerator Enablement) <[email protected]> Add passes as option to AITTestCase.run_test
f7bc0c543b553ca2f80149995b4c28599a6ea396 Ying Zhang <[email protected]> Back out "Add passes as option to AITTestCase.run_test"
22d4044c66720e0e656f41538c81a3e90ef1a433 Zhijing Li (Accelerator Enablement) <[email protected]> Relaunch add passes as option to AITTestCase.run_test
ae0de22b6a97bca82c0ef6a14b0be2b570eb443a Eli Uriegas <[email protected]> Remove fx2trt/torch2trt backends (#93822)
b08e568951c911e4c3bbc72b55830fa1d4be4b2b Eli Uriegas <[email protected]> Remove torch/_dynamo/optimizations (#93871)
725266c0b7eb0549060e79b65346d703cc5bc39e Benson Ma <[email protected]> [T143761882] Migrate CUTLASS group gemm operators to OSS
44110f5df422e84cd9d9afbf5dfbe742057b9d98 Zhijing Li (Accelerator Enablement) <[email protected]> Add noop pass for torch.ops.fb.scale_gradient
84befb25b778485c8694ba659248d4d570d92390 Chao Gu <[email protected]> [FX] Add log_softmax
b641713bd774cb7c7bf903f514bff5c87a6f3a33 Wei Wei <[email protected]> [fx2ait] support torch.unbind, torch.group_norm
d263b38b53b93a18a78cd34b2a1c48711c3c59cd Shirong Wu <[email protected]> Add extra logging for layer norm
eb2591231195cc0ab6780f345f095807a7d45f7c Callum Ryan <[email protected]> Make GPU test run in bundled mode
f63d3834e87a819f8335c50b351e48f60573d474 Sarunya Pumma <[email protected]> Back out "[T143761882] Migrate CUTLASS group gemm operators to OSS"
a9f489c1c3a182698385053c0a94b792c4e310ba Shirong Wu <[email protected]> Change opt_profile_replica to 3
b8bdde86f0bae6010062c33aec03a4e13a87a6ab Brian Hirsh <[email protected]> forward fix for new _native_batch_norm_legit_no_training op
e8f4cbd46402e5603cc48d24395db3f0e010581a Shirong Wu <[email protected]> Fix reshape op
b860725bfaf74a0043190d1140ddee987dd82d0c generatedunixname89002005232357 <[email protected]> Revert D43278214: Multisect successfully blamed D43278214 for test or build failures
d4ea365cf8aa56d752912f7878b8046e89c804c2 Chunxing Yin <[email protected]> [mtia] Add sigmoid_backward kernel binding in Glow
a768c82a51a058e56a64ff82f90e619795611b66 Mor Tzur <[email protected]> lower to ait
8eb52426aaca586ae50fde75cccca6a0827a8328 Wei Wei <[email protected]> [hstu][fx2ait] op support
55d95ffa096d9de7952a6a1c4628efd67e554d82 Wei Wei <[email protected]> [fx2ait] temp solution to set correct dynamic batch size for jagged tensor
0a42e2f0874c48e9b60503a25705f0fc6319ff87 Jia Jiunn Ang <[email protected]> [CMF] chunk-squeeze-cat op fusion when split on last dimension
8bd509596a799f1270796772e12be090a6db5d39 Wei Wei <[email protected]> [aten2trt] update comment
1761b440d646836116fdadf2b5c7b55c7d2b989b Oleg Khabinov <[email protected]> [fx2ait] Fix a dper pass when acc_ops.squeeze doesn't have a dim
3cc405a92c9fcec886d890de87ac94e024c682a5 Jia Jiunn Ang <[email protected]> [CMF] Fuse chunk-linear-cat into baddbmm
5f42f56c5b5d0bd4c058aa280a980e64dd89b0a9 Xiaodong Wang <[email protected]> [cudnn] static linking
229969542a2c1e96fe8345ff7adc2fd48f6a0707 Romain Sauvestre <[email protected]> Remove base_module from acc_tracer target
a174195c484d5a25f06e4c0665bbb2e9d9dcae82 Janet Yang <[email protected]> Support input_tensor_spec w/ multiple batch dims in TRT
0246365e6facc6dfb13843fa9854802f35c0938a Zhijing Li (Accelerator Enablement) <[email protected]> Remove noop dropout op with acc tracer
4c287b9f6238e8bbbd80e742262a0eee6efa57de Kunming Ho <[email protected]> Operator support for threshold_backward
71bb34c81289173b83c7e7cf544b851096d9d99d Fei Kou <[email protected]> specialize_int_float to specialize_int from D43925225
037db53f89a7b863ef0fbaa7b94425fd9a08dc96 Wei Wei <[email protected]> enable torchscripting
77f3dce76fd5407b08826f67213d8299d9d48542 Adnan Akhundov <[email protected]> [fx2ait] Extend jagged tensor support
e6b551e48a0c03db63fc46ff85d975b489e30079 Jordan Fix <[email protected]> [acc_tracer] Add dont_retrace_gm option
ada3cbbb3d6c3b3631496a3bceea775f45649c6c Adam Simpkins <[email protected]> Fix a bunch of invalid Python escape warnings in torch_tensorrt
98254d631e8748a85b05851c97fb74f3e3922cfe Brandon Tran (Realtime Integrity) <[email protected]> Add torch.nn.functional.normalize to TensorRT
fce21e2248ad0fddfcc12dbe2e3a9a6ac9ea2a5f Shirong Wu <[email protected]> Fix trt input spec
a08bad1ac74a6d1409bb3f2e96953ed0c149d006 Wei Wei <[email protected]> [fx2ait] changes to improve jagged tensor and add b2b bmm
7745d70a17677777dcb5806e1e8008532f961f5d generatedunixname485339166882981 <[email protected]> [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+gpu_enablement_0
ba33951ae2d2ebc99794aff8026a01a31f9ad8da Shirong Wu <[email protected]> Add ait full op converters
b3bfd69f15fc4e32f27217a3efa8204a2f062af8 Chao Gu <[email protected]> [FX] support index_add in acc ops and tracer
a965bafc517afc81591052e355fd34062b028a89 Shirong Wu <[email protected]> Make fill op read dtype from input/kwarg
72f9b0925eceffc12dfa51769c1bd0cb38a3e50c generatedunixname485339166882981 <[email protected]> [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+gpu_enablement
2e7feece191d6178ff6ec750d8fe481175bb27b9 Max Podkorytov <[email protected]> [fx2ait] enable lowering to bfloat16
94607911ffb11e78082e061a670b5140e9a55d72 Archie Sravankumar <[email protected]> Add support for nan_to_num
42fddd20d303dbbc3355a8c09a86d4a74317be97 Max Podkorytov <[email protected]> [AITemplate] feed_lower_benchmark cli argument tweak for choosing precision
648ec682f2214e67912fe7c800f7ca059195cf4e Huamin Li <[email protected]> Re-enable previous disabled TRT unit tests
3e5c2aac8a7b9e50efe04fcae361a3c0ee1777a7 Janet Yang <[email protected]> Skip acc normalization of repeat_interleave if input dims aren't integral
f412f35baeee9a1b17f67b7749ca1f9b8cbbe77b Janet Yang <[email protected]> Skip acc normalization of repeat if dims aren't ints
5b9cfe428f29e27da76b19029bda03a8b43c17d1 Huamin Li <[email protected]> add import into generate_standalone_repro
9f88965e87e72658aa6a4973dc870d50b8a22ca4 Fei Kou <[email protected]> lowering with bf16
7f761df34d672c87c40b18369b28bc593374122c Fei Kou <[email protected]> [benchmark] Support bfloat16 in mts_gpu_benchmark
fa9b09e11ba8f888d761e1398367973d30e0aa1e Wei Wei <[email protected]> [fx2ait] add a simple eager run to verify the input generatation is correct
4f8ca36dbdc72dfa60e667c3592d0a2bc466b994 Max Podkorytov <[email protected]> [AITemplate] implement per op benchmark backbone
9873be1e82f2dd4a8a768497ac9cdb3b9b95cfe9 Thomas Orozco <[email protected]> buck2: tag more long running tests
0d6827c464aa2141a48a8d768a8c7facd65c0bc4 generatedunixname485339166882981 <[email protected]> [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+gpu_enablement_0_2ea3
04f9c1105a2a6a711d025d5c85b95147343d0ecd Zhijing Li (Accelerator Enablement) <[email protected]> [fx2ait] Fix acc_ops converter on std when keepdim=False
906bad1deebb235a9c80d0f0d46145da08afa091 Danylo Baibak <[email protected]> Forward fix to switch calling convention back to real tensors
48ffa2ab3dd66487922f9f0bf9a145db6eaf3fe2 Kefei Lu <[email protected]> Lowering: validate inference with alternative batch sizes
ca5dc1a2896bd476e3a327db834df859a3fcc11f Jordan Fix <[email protected]> [fba_pass_manager_builder][BE] General cleanup/refactor
afb4df5e84571f466b0f385472493aefb89344cc Shirong Wu <[email protected]> Mask select converter
25e8afb1f8be19ec6c4ef4bc74ea48e64017cde2 Janet Yang <[email protected]> Fix lowering FusedTsEncodingModule for coffee model
7fdf06ecfc6b4efb7008ce399dcd0c32ef1f1f75 generatedunixname485339166882981 <[email protected]> [Codemod][[pyunit][static_listing] Convert python unit test dynamic listing to static listing] oncall+deprecated_techdebt_do_not_use_4_9c34
a58c5e454412585c4cc48ced1798dbf234cc13b6 Michael Liu <[email protected]> Initialize `output_count` in `get_model_info_str`
2c6f13ddcc52e8f833fcd164d0c479ca3398322e Wei Wei <[email protected]> jagged SHA and MHA module support
2fe5c7cd3b763b839af3d1b05eecc73f1df05286 Shirong Wu <[email protected]> Add BF16 support for ads model
2486edbe5013f3b7e5807503538f3164bdd4ee19 Shirong Wu <[email protected]> Add low_level_module conversion pass
ca7c51407ab0410d311c984b31aeb757dd840bc2 Wei Wei <[email protected]> [hstu] remove torch_package from RelativeBucketedTimeAndPositionBasedBias after packaged
80596e459343d5630e16a6175eafffd2c25a3123 Shirong Wu <[email protected]> Block a pass that yield problem
ded609195500a8edc5bed80ee85f41b35224c19f Huamin Li <[email protected]> Do not test test_implicit_batch_dim if >= 8.6
8e8e736e14d23e77fa2bd5e72123d66943f7716f Huamin Li <[email protected]> Speed up TRT compile time in test env
2db82572e509cfe827c34a4060c058ae44b5547a Jordan Fix <[email protected]> [acc_tracer] Add in use_concrete_args option to enable pytree flatten/unflatten
946f957b6636c6b4f64e52148c9baf6e0351fb5e Wei Wei <[email protected]> [hstu] changes to bias module and sha/mha pass to adapt to removing presences
d904b26386c2ef98f292edae7c5e98c27119f9d9 Oleg Khabinov <[email protected]> [fx2ait] Rename split_op_to_reshape to chunk_op_to_reshape
ca36733f0ea67aeeb38a3740f795bbf99b24037b Oleg Khabinov <[email protected]> [fx2ait] Rewrite chunk_op_to_reshape() to use while loop instead of recursion
4361feb4399eec3816b534991020703d099d2896 Oleg Khabinov <[email protected]> [fx2ait] Optimize chunk_op_to_reshape()
071b84e3cda4f0175b37ae62c37b2d4f2de7925f Huamin Li <[email protected]> Disable libkineto for TRT unit tests
92f9acaac8f9a8f0fc2e1382bf4c79d0b94cbea5 Wei Wei <[email protected]> [fx2ait] improve bf16 support
8b92e8356278eb9676a5299373841593af942fb4 Jongsoo Park <[email protected]> [acc_tracer] skip None module in rewriting
0d1d644bad22c86efec12009ca1464587d1e7d38 Kefei Lu <[email protected]> Remove non-existent argument doc string
2efe5e78bc8627a30ba132e5b8e14e06538d463f shirong <[email protected]> Temp fix
a15a564a567eb689604d27ca814553e38c287698 shirong <[email protected]> Temporary commit at 4/24/2023, 2:32:22 PM
78825462243c09760ebb73156a4c18bbc9ddee75 shirong <[email protected]> Temporary commit at 4/24/2023, 2:32:37 PM
9bfea274462fd77cb04c38c17bc237541af87c55 laksrm <[email protected]> [DNR] onboard ctr to aimp with lowering fixes
8bb482b10f7f63270c329c88d5ac028b40f6b757 shirong <[email protected]> Reenable pass
1 parent 81adf2b · commit eaf2bb4

26 files changed: +1193 −197 lines

py/torch_tensorrt/fx/converters/acc_ops_converters.py (+5 −2)

@@ -701,10 +701,13 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
     eps_field = trt.PluginField(
         "eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
     )
+    normalized_shape = kwargs["normalized_shape"]
     try:
-        normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
+        normalized_shape = np.array(normalized_shape, dtype=np.int32)
     except TypeError:
-        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
+        _LOGGER.error(
+            f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []"
+        )
         normalized_shape = np.array([], dtype=np.int32)

     normalized_shape_filed = trt.PluginField(
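
The functional change is confined to the except branch: normalized_shape is hoisted out of the np.array() call so the failing value can appear in the log message. A minimal standalone sketch of that error path; the (None, 768) value is a made-up example of a shape np.array() cannot cast to int32:

    import numpy as np

    normalized_shape = (None, 768)  # hypothetical value that triggers the TypeError
    try:
        normalized_shape = np.array(normalized_shape, dtype=np.int32)
    except TypeError:
        # The raw value is still in scope, so the message can include it.
        print(
            f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []"
        )
        normalized_shape = np.array([], dtype=np.int32)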

py/torch_tensorrt/fx/diagnostics.py (+8)

@@ -87,12 +87,14 @@ class DiagnosticsWriter:

     def __init__(self):
         self._root_dir = tempfile.mkdtemp(prefix="fx2trt.")
+        self._data = ""
         _LOGGER.info(f"Initializing DiagnosticsWriter with root_dir: {self._root_dir}")

     def write(self, file_name: str, data: WriteObj):
         """
         TODO: Can be disabled by regex on file_name
         """
+        self._data = data
         # Only write if we are inside a collect_when() context.
         if not _IS_IN_COLLECT_CONTEXT.get(False):
             return

@@ -117,6 +119,9 @@ def write(self, file_name: str, data: WriteObj):
     def root_dir(self) -> str:
         return self._root_dir

+    def data(self) -> WriteObj:
+        return self._data
+
     def _write(self, file_name: str, to_write: bytes):
         # ms granularity - no naming collash, otherwise file will be
         # overwritten.

@@ -271,6 +276,9 @@ def collect(self) -> str:
         finally:
             os.remove(fp)

+    def data(self) -> WriteObj:
+        return self._write.data()
+

 def _res_or_err(data: WriteObj) -> t.Tuple[TWrite, str]:
     if isinstance(data, (str, bytes)):
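
Since write() now unconditionally caches its payload in self._data, the last write can be inspected even when no collect_when() context is active. A hedged usage sketch; the file name and payload below are made up:

    from torch_tensorrt.fx.diagnostics import DiagnosticsWriter

    writer = DiagnosticsWriter()
    writer.write("lowering_report.txt", "layer_norm fallback triggered")
    # Cached even though no collect_when() context is active:
    assert writer.data() == "layer_norm fallback triggered"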

py/torch_tensorrt/fx/fx2trt.py (+6)

@@ -1,4 +1,5 @@
 import logging
+import os
 import warnings
 from datetime import datetime
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence

@@ -211,6 +212,11 @@ def run(
         builder_config = self.builder.create_builder_config()
         builder_config.max_workspace_size = max_workspace_size

+        # Speed up TRT build time in the test environment
+        if trt.__version__ >= "8.6" and os.environ.get("TRT_TEST_ENV", "0") == "1":
+            _LOGGER.info("Set TRT optimization level to 0")
+            builder_config.builder_optimization_level = 0
+
         cache = None
         if timing_cache:
             cache_file = numpy.array(timing_cache)
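
Opting in is just a matter of exporting the environment variable before run() builds the engine. Note that trt.__version__ >= "8.6" is a lexicographic string comparison, which holds across the 8.x series this targets. A sketch of the intended usage:

    import os

    # Must be set before run() creates the builder config.
    os.environ["TRT_TEST_ENV"] = "1"
    # With TensorRT >= 8.6, run() will then log "Set TRT optimization level to 0"
    # and build with builder_optimization_level = 0: faster builds at the cost of
    # less-optimized engines, a reasonable trade-off for unit tests.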

py/torch_tensorrt/fx/input_tensor_spec.py (+62 −11)

@@ -1,4 +1,4 @@
-from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, Iterable, List, NamedTuple, Optional, Sequence, Tuple

 import torch

@@ -18,6 +18,12 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
     # is the dynamic batch dimension. Otherwise, we use the additional
     # inputs to determine the batch dimension.
     if additional_inputs is None:
+        batch_dims = None
+        if not isinstance(inputs, torch.Tensor) and len(inputs) > 1:
+            bs = inputs[0].size(0)
+            batch_dims = None
+            if not all(x.size(0) == bs for x in inputs):
+                batch_dims = InputTensorSpec.find_batch_size_dim(inputs)
         return InputTensorSpec.from_tensors_with_dynamic_batch_size(
             inputs,
             (

@@ -26,6 +32,7 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
                 lower_setting.max_batch_size,
             ),
             lower_setting.opt_profile_replica,
+            batch_dims,
         )
     else:
         batch_dims = []

@@ -147,25 +154,69 @@ def from_tensors_with_dynamic_batch_size(
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
         if batch_dims is None:
-            batch_dims = [0] * len(tensors)
+            batch_dims = cls.find_batch_size_dim(tensors)

         input_specs = []
         batch_size = tensors[0].size(batch_dims[0])

         for i, tensor in enumerate(tensors):
             batch_dim = batch_dims[i]
-            assert batch_size == tensor.size(
-                batch_dim
-            ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
-            shape = list(tensor.shape)
-            shape[batch_dim] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
-            input_specs.append(
-                cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
-            )
+            if batch_dim == -1:
+                input_specs.append(cls.from_tensor(tensor))
+            else:
+                shape = list(tensor.shape)
+                assert batch_size == tensor.size(
+                    batch_dim
+                ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
+                shape[batch_dim] = -1
+                shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+                input_specs.append(
+                    cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
+                )

         return input_specs

+    @classmethod
+    # pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any`
+    def find_batch_size_dim(cls, inputs: Any) -> []:
+        if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
+            return [0]
+        shapes = [i.shape for i in inputs]
+        frequency_map = {}
+        first_dims = set()
+        for shape in shapes:
+            if len(shape) < 2:
+                # By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
+                continue
+            # Dedup shape value for single tensor
+            first_dims.add(shape[0])
+            shape = set(shape)
+            for i in shape:
+                frequency_map[i] = frequency_map.get(i, 0) + 1
+
+        if len(first_dims) == 1:
+            # first dim is the same in every input: we use it as batch_size
+            batch_size = first_dims.pop()
+        elif frequency_map:
+            # first dims are different: we use the most frequent dim as batch_size
+            sorted_frequency = sorted(frequency_map.items(), key=lambda x: -x[1])
+            batch_size = sorted_frequency[0][0]
+        else:
+            # no dims to sort: no batch_size
+            batch_size = -1
+
+        bs_dim = []
+        for i in inputs:
+            # Default batch size dim = -1, indicate no batch_size
+            dim = -1
+            for index, val in enumerate(i.shape):
+                if val == batch_size:
+                    dim = index
+                    break
+            bs_dim.append(dim)
+
+        return bs_dim
+
     def to_random_tensor(self, id=1):
         shape = tuple(self.shape)
         if len(get_dynamic_dims(shape)):
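
A worked example tracing the heuristic above. Rank-1 tensors are skipped during the frequency vote; since every remaining tensor starts with the same first dim, that dim wins as the batch size, and tensors with no matching dim get -1 (which from_tensors_with_dynamic_batch_size now routes to a static from_tensor spec). The shapes are made up:

    import torch
    from torch_tensorrt.fx.input_tensor_spec import InputTensorSpec

    inputs = [
        torch.zeros(8, 128),  # dims voted on: {8, 128}
        torch.zeros(8, 64),   # dims voted on: {8, 64}
        torch.zeros(50),      # rank-1: excluded from the vote
    ]
    # First dims of the rank>=2 tensors agree, so batch_size = 8; the rank-1
    # tensor has no dim equal to 8, so its entry is -1 (no batch dim).
    print(InputTensorSpec.find_batch_size_dim(inputs))  # [0, 0, -1]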

py/torch_tensorrt/fx/lower.py (+4)

@@ -41,6 +41,8 @@ def compile(
     dynamic_batch=True,
     is_aten=False,
     use_experimental_fx_rt=False,
+    correctness_atol=1e-1,
+    correctness_rtol=1e-1,
 ) -> nn.Module:
     """
     Takes in original module, input and lowering setting, run lowering workflow to turn module

@@ -81,6 +83,8 @@ def compile(
         dynamic_batch=dynamic_batch,
         is_aten=is_aten,
         use_experimental_rt=use_experimental_fx_rt,
+        correctness_atol=correctness_atol,
+        correctness_rtol=correctness_rtol,
     )
     lowerer = Lowerer.create(lower_setting=lower_setting)
     return lowerer(module, input)
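
The two new keyword arguments flow straight into LowerSetting and set the tolerances used when the lowering workflow validates lowered outputs against the original module. A hedged usage sketch; the toy model and inputs are stand-ins, and all other compile() parameters keep their defaults:

    import torch
    import torch.nn as nn
    from torch_tensorrt.fx.lower import compile as fx2trt_compile

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU()).cuda().eval()
    inputs = [torch.randn(4, 16, device="cuda")]

    trt_mod = fx2trt_compile(
        model,
        inputs,
        correctness_atol=1e-2,  # tighter than the 1e-1 defaults above
        correctness_rtol=1e-2,
    )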

py/torch_tensorrt/fx/passes/lower_basic_pass.py (+41 −5)

@@ -54,10 +54,14 @@ def fill_with_mul_zero_and_add(*args):


 def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    # Now we do constant folding on traced module. We want to skip pattern like
-    # weights -> quant -> dequant -> op during constant folding when the model is
-    # a quantized int8 model.
-    def skip_folding_quant_dequant(node: torch.fx.Node):
+    def skip_folding_ops(node: torch.fx.Node):
+        # dtype op
+        if node.target == acc_ops.dtype:
+            return True
+        # Now we do constant folding on traced module. We want to skip pattern like
+        # weights -> quant -> dequant -> op during constant folding when the model is
+        # a quantized int8 model.
+        # quant_dequant
         if node.target != acc_ops.quantize_per_tensor:
             return False
         # If quantize_per_node -> dequantize, then skip folding.
@@ -66,7 +70,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
             return True
         return False

-    const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
+    const_split_mod = split_const_subgraphs(traced_mod, skip_folding_ops)
     const_split_mod.run_folding()
     return const_split_mod

@@ -630,3 +634,35 @@ def fix_clamp_numerical_limits_to_fp16(

     mod.recompile()
     return mod
+
+
+@log_before_after
+@validate_inference(atol=1e-3, rtol=1e-2)
+def remove_dtype_and_to_pattern(
+    mod: torch.fx.GraphModule, input: Input
+) -> torch.fx.GraphModule:
+    """
+    Remove this pattern since it is unnecessary to cast to dtype
+    %dtype : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.dtype](args = (), kwargs = {input: %_attention_layers_0__uva})
+    %to_18 : [#users=2] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.to_dtype](args = (), kwargs = {input: %x})
+    """
+    for node in mod.graph.nodes:
+        if node.op == "call_function" and node.target == acc_ops.dtype:
+            # find its first user
+            next_node = next(iter(node.users))
+            # acc_op or pt op is treated differently
+            input = (
+                next_node.kwargs["input"]
+                if "input" in next_node.kwargs
+                else next_node.args[0]
+            )
+            if len(node.users) == 1 and (
+                next_node.target == acc_ops.to_dtype or next_node.target == "to"
+            ):
+                next_node.replace_all_uses_with(input)
+                mod.graph.erase_node(next_node)
+                mod.graph.erase_node(node)
+
+    mod.graph.eliminate_dead_code()
+    mod.recompile()
+    return mod
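
The new pass only rewrites when the acc_ops.dtype node has exactly one user and that user is an acc_ops.to_dtype (or "to") cast; both nodes are then erased and the cast's consumers are rewired to its input. A hedged invocation sketch; traced_mod and sample_input are assumed to exist:

    from torch_tensorrt.fx.passes.lower_basic_pass import remove_dtype_and_to_pattern

    # traced_mod: an acc-traced torch.fx.GraphModule containing the
    # dtype -> to_dtype pattern from the docstring above.
    # sample_input: used by @validate_inference to check numerics
    # (atol=1e-3, rtol=1e-2) after the rewrite.
    cleaned = remove_dtype_and_to_pattern(traced_mod, sample_input)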

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py (+5 −2)

@@ -8,6 +8,7 @@
 from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
+from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion
 from torch_tensorrt.fx.utils import LowerPrecision

 from ..input_tensor_spec import generate_input_specs

@@ -229,10 +230,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:
             submod = getattr(split_result.split_module, submod_name)

             LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)
-
             # Only acc submodules will be lowered.
             if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                _LOGGER.info(f"Now lowering submodule {submod_name}")
+                _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                 lowering_start_time = datetime.datetime.now()

                 self.lower_setting.additional_inputs = (

@@ -251,6 +251,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                 _LOGGER.info(
                     f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                 )
+            else:
+                _LOGGER.info(f"GPU submodule graph: {submod.graph}")
+                apply_bfloat_float_conversion(submod, submod_inputs, submod_name)

         return split_result.split_module
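
The net effect on lower_func's per-submodule loop: acc submodules still take the TRT/AIT lowering path, while non-acc (GPU) submodules now get their graph logged and a bfloat16-to-float conversion applied. A condensed restatement of that dispatch; the loop header and lower_acc_submodule are hypothetical stand-ins for code elided from the hunk:

    for submod_name, submod_inputs in submodule_inputs.items():  # assumed loop shape
        submod = getattr(split_result.split_module, submod_name)
        if not submod_name.startswith(split_result.non_acc_submodule_prefix):
            lower_acc_submodule(submod, submod_inputs)  # stand-in for the TRT/AIT path
        else:
            apply_bfloat_float_conversion(submod, submod_inputs, submod_name)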
