
Commit 7dc86e1

Fix rebase issues
Signed-off-by: Hao Lu <[email protected]>
1 parent f9f404e commit 7dc86e1

35 files changed: +24 -292 lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/CMakeLists.txt (+1 -1)

@@ -17,4 +17,4 @@
 
 add_subdirectory(fmha)
 add_subdirectory(blockscaleGemm)
-add_subdirectory(fp8BlockScaleMoe)
+add_subdirectory(blockScaleMoe)

cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp (+1 -1)

@@ -15,7 +15,7 @@
 */
 
 #include "tensorrt_llm/kernels/quantization.h"
-#include "tensorrt_llm/kernels/trtllmGenKernels/fp8BlockScaleMoe/runner.h"
+#include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h"
 #include "tensorrt_llm/runtime/torchUtils.h"
 #include "tensorrt_llm/thop/thUtils.h"
 #include <ATen/cuda/EmptyTensor.h>

cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp (+1 -1)

@@ -14,7 +14,7 @@
 * limitations under the License.
 */
 
-#include "tensorrt_llm/kernels/trtllmGenKernels/fp8BlockScaleMoe/runner.h"
+#include "tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h"
 #include "tensorrt_llm/runtime/torchUtils.h"
 #include "tensorrt_llm/thop/thUtils.h"
 #include <ATen/cuda/EmptyTensor.h>

examples/mmlu_llmapi.py (+5 -1)

@@ -241,6 +241,10 @@ def parse_args():
                         default='TRTLLM',
                         choices=['TRTLLM', 'FLASHINFER'],
                         help='Attention kernel for PyTorch. Ignored for TRT backend.')
+    parser.add_argument('--moe_backend',
+                        type=str,
+                        default='CUTLASS',
+                        choices=['CUTLASS', 'TRTLLM'])
     parser.add_argument("--enable_chunked_prefill",
                         action="store_true",
                         help="Exercises the chunked prefill inference feature.")
@@ -307,7 +311,7 @@ def main():
         assert args.engine_dir is None, "pytorch backend does not need TRT Engine"
         config = PyTorchConfig(
             attn_backend=args.attn_backend,
-            moe_backend='TRTLLM',
+            moe_backend=args.moe_backend,
             enable_overlap_scheduler=args.enable_overlap_scheduler,
             torch_compile_enabled=args.torch_compile)
         llm = tensorrt_llm._torch.LLM(
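
Note: for reference, a minimal standalone sketch of the new --moe_backend flag in isolation; only the argument added above is shown, the rest of parse_args()/main() is elided, and the forwarding comment mirrors the second hunk.

import argparse

# Hypothetical, stripped-down parser containing only the new flag.
parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=['CUTLASS', 'TRTLLM'])

args = parser.parse_args(['--moe_backend', 'TRTLLM'])
# The parsed value now reaches PyTorchConfig(moe_backend=args.moe_backend, ...)
# instead of the previously hardcoded 'TRTLLM'.
print(args.moe_backend)  # -> TRTLLM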

tensorrt_llm/_torch/modules/attention.py (+1 -1)

@@ -361,6 +361,7 @@ def __init__(
 
         if quant_mode.has_fp8_block_scales():
             mla_weight_dtype = torch.float8_e4m3fn
+            # TODO: remove hack for fp8 Deepseek on SM100
             if config.moe_backend == "TRTLLM":
                 mla_weight_dtype = dtype
         else:
@@ -487,7 +488,6 @@ def forward(
         attn_metadata: AttentionMetadata,
         all_reduce_params: Optional[AllReduceParams] = None,
     ) -> torch.Tensor:
-        assert hidden_states.dtype == torch.bfloat16, "Just for TRTLLM FP8 E2E test"
         if self.is_lite:
             compressed_kv, k_pe = self.fused_a(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)

tensorrt_llm/_torch/modules/fused_moe.py (+4 -8)

@@ -335,6 +335,7 @@ def setup_quant_scales(self):
             fc2_weight_block=self.w2_weight_scale,
             fc2_global=self.fc2_alpha,
         )
+
     def is_trtllm(self):
         return self.moe_backend == "TRTLLM" and self.quant_config is not None
 
@@ -416,6 +417,7 @@ def create_weights(self):
             self.register_parameter("w2_weight_scaling_factor",
                                     w2_weight_scaling_factor)
         elif qc.quant_mode.has_nvfp4():
+            self.has_nv_fp4 = True
             if self.is_trtllm():
                 weight_dtype = float4_sf_dtype
                 weight_vec_size = torch.iinfo(weight_dtype).bits // 4
@@ -668,7 +670,8 @@ def forward(
         all_rank_num_tokens: Optional[List[int]] = None,
     ) -> torch.Tensor:
         if self.is_cutlass():
-            return self.forward_cutlass(x, router_logits, min_latency_mode, output_dtype, all_rank_num_tokens)
+            return self.forward_cutlass(x, router_logits, min_latency_mode,
+                                        output_dtype, all_rank_num_tokens)
         elif self.is_trtllm():
             return self.forward_trtllmgen(x, router_logits)
         else:
@@ -763,14 +766,7 @@ def forward_trtllmgen(self, x: torch.Tensor,
 
         if self.quant_config and self.quant_config.quant_mode.has_fp8_block_scales(
         ):
-            # TODO: We need a new kernel to support fp8 block scaling for blackwell
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
-            m_4_align = (x.shape[0] + 3) // 4 * 4
-            kscal_128 = (x.shape[1] + 127) // 128
-            act_scal_elesize = kscal_128 * m_4_align
-            x_scale = x_scale[:act_scal_elesize]
-            x_scale = x_scale.view(kscal_128, m_4_align)
-            x_scale = x_scale[:kscal_128, :x.shape[0]].contiguous()
 
             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
                 router_logits,
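
Note: for reference only, a sketch of the scale-layout workaround removed above (and in linear.py below), restated as a standalone helper. It assumes fp8_quantize_1x128 returned a flat per-1x128-block scale buffer with the token dimension padded to a multiple of 4; the commit drops this unpacking, so the op's output is presumably consumed directly by fp8_block_scale_moe_runner now.

import torch

def legacy_unpack_act_scales(x: torch.Tensor,
                             flat_scales: torch.Tensor) -> torch.Tensor:
    # One scale per 128-wide slice of the hidden dimension, token count
    # padded up to a multiple of 4 (mirrors the deleted lines).
    m_4_align = (x.shape[0] + 3) // 4 * 4
    kscal_128 = (x.shape[1] + 127) // 128
    scales = flat_scales[:kscal_128 * m_4_align]
    scales = scales.view(kscal_128, m_4_align)
    # Drop the row padding so the result is (num_k_blocks, num_tokens).
    return scales[:, :x.shape[0]].contiguous()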

tensorrt_llm/_torch/modules/linear.py (+3 -13)

@@ -322,15 +322,9 @@ def apply_linear(self, input, weight, bias):
             if input.dtype == torch.float8_e4m3fn:
                 input = input.to(torch.bfloat16) * self.input_scale
             assert input.dtype == torch.bfloat16
-            # TODO: We need a new kernel to support fp8 block scaling for blackwell
-            act_input_fp8, a_scale = torch.ops.trtllm.fp8_quantize_1x128(
+
+            act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                 input)
-            m_4_align = (input.shape[0] + 3) // 4 * 4
-            kscal_128 = (input.shape[1] + 127) // 128
-            act_scal_elesize = kscal_128 * m_4_align
-            a_scale = a_scale[:act_scal_elesize]
-            a_scale = a_scale.view(kscal_128, m_4_align)
-            act_input_sf = a_scale[:kscal_128, :input.shape[0]].contiguous()
 
             output = torch.ops.trtllm.fp8_block_scaling_gemm(
                 act_input_fp8, self.weight, act_input_sf, self.weight_scale)
@@ -402,11 +396,7 @@ def load_weights(self, weights: List[Dict]):
         assert self._weights_created
 
         def copy(dst: Parameter, src: torch.Tensor):
-            # TODO: Update this once we have BMM FP8 working with blackwell
-            #assert dst.dtype == src.dtype, f"Incompatible dtype. dst: {dst.dtype}, src: {src.dtype}"
-            assert dst.dtype == src.dtype or (
-                dst.dtype == torch.bfloat16 and src.dtype == torch.float8_e4m3fn
-            ), f"Incompatible dtype. dst: {dst.dtype}, src: {src.dtype}"
+            assert dst.dtype == src.dtype, f"Incompatible dtype. dst: {dst.dtype}, src: {src.dtype}"
             dst.data.copy_(src)
 
         weight_mode = self.weights_loading_config.weight_mode
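
Note: after this change the FP8 block-scale path in apply_linear reduces to quantize-then-GEMM. A minimal sketch using only the custom ops referenced in the diff (bias handling and module state elided, shapes assumed compatible):

import torch

def fp8_block_scale_linear(input: torch.Tensor, weight: torch.Tensor,
                           weight_scale: torch.Tensor) -> torch.Tensor:
    # Quantize activations to FP8 with per-1x128-block scale factors; the op
    # is assumed to return the scales in the layout the GEMM consumes.
    act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(input)
    # Block-scaled FP8 GEMM against the pre-quantized weight and its scales.
    return torch.ops.trtllm.fp8_block_scaling_gemm(
        act_input_fp8, weight, act_input_sf, weight_scale)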

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (+1 -1)

@@ -183,7 +183,7 @@ def create_py_executor(executor_config: ExecutorConfig,
         executor_config.pytorch_backend_config.use_kv_cache = False
 
     kv_cache_max_tokens = None
-    # TODO: remove this once we have a loop fix for routing token limit
+
     if model_engine.model.model_config.is_generation:
         kv_cache_max_tokens = estimate_max_kv_cache_tokens(
             model_engine, executor_config, mapping)

tests/_torch/multi_gpu_modeling/test_deepseek.py (-209)

This file was deleted.
