@@ -335,6 +335,7 @@ def setup_quant_scales(self):
             fc2_weight_block=self.w2_weight_scale,
             fc2_global=self.fc2_alpha,
         )
+
     def is_trtllm(self):
         return self.moe_backend == "TRTLLM" and self.quant_config is not None
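A note on the two scale fields wired up above: NVFP4 checkpoints carry a per-block weight scale (`w2_weight_scale`, bound to `fc2_weight_block`) plus a single global scale (`fc2_alpha`, bound to `fc2_global`). A minimal dequantization sketch of that two-level scheme, assuming 16-element blocks; every name below is invented for illustration, not taken from this diff:

```python
import torch

def dequant_nvfp4_sketch(w_fp4: torch.Tensor, block_scale: torch.Tensor,
                         global_scale: float, block: int = 16) -> torch.Tensor:
    # w_fp4: already-unpacked FP4 values as float, shape [rows, cols]
    # block_scale: one scale per `block`-element group along the last dim,
    #              shape [rows, cols // block]
    # Broadcast each block scale over its group, then apply the global
    # scale on top (two-level scaling).
    expanded = block_scale.repeat_interleave(block, dim=-1)
    return w_fp4 * expanded * global_scale
```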
@@ -416,6 +417,7 @@ def create_weights(self):
             self.register_parameter("w2_weight_scaling_factor",
                                     w2_weight_scaling_factor)
         elif qc.quant_mode.has_nvfp4():
+            self.has_nv_fp4 = True
             if self.is_trtllm():
                 weight_dtype = float4_sf_dtype
                 weight_vec_size = torch.iinfo(weight_dtype).bits // 4
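`weight_vec_size` here is just the number of 4-bit values packed into one storage element. Assuming `float4_sf_dtype` is byte-backed (an assumption, not confirmed by this diff), the arithmetic works out to two FP4 values per element:

```python
import torch

# Assume the FP4 weights are stored packed in uint8 (an assumption;
# the real float4_sf_dtype may differ).
weight_dtype = torch.uint8
weight_vec_size = torch.iinfo(weight_dtype).bits // 4  # 8 // 4 == 2

hidden_size = 4096
packed_cols = hidden_size // weight_vec_size           # 2048 elements per row
print(weight_vec_size, packed_cols)                    # 2 2048
```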
@@ -668,7 +670,8 @@ def forward(
         all_rank_num_tokens: Optional[List[int]] = None,
     ) -> torch.Tensor:
         if self.is_cutlass():
-            return self.forward_cutlass(x, router_logits, min_latency_mode, output_dtype, all_rank_num_tokens)
+            return self.forward_cutlass(x, router_logits, min_latency_mode,
+                                        output_dtype, all_rank_num_tokens)
         elif self.is_trtllm():
             return self.forward_trtllmgen(x, router_logits)
         else:
@@ -763,14 +766,7 @@ def forward_trtllmgen(self, x: torch.Tensor,

         if self.quant_config and self.quant_config.quant_mode.has_fp8_block_scales(
         ):
-            # TODO: We need a new kernel to support fp8 block scaling for blackwell
             x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
-            m_4_align = (x.shape[0] + 3) // 4 * 4
-            kscal_128 = (x.shape[1] + 127) // 128
-            act_scal_elesize = kscal_128 * m_4_align
-            x_scale = x_scale[:act_scal_elesize]
-            x_scale = x_scale.view(kscal_128, m_4_align)
-            x_scale = x_scale[:kscal_128, :x.shape[0]].contiguous()

             final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
                 router_logits,
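The deleted lines were a host-side layout fix-up for the activation scales. Judging from the removed arithmetic, `torch.ops.trtllm.fp8_quantize_1x128` emits one scale per 1x128 block of `x`, flattened with the token dimension padded up to a multiple of 4; the old code trimmed that padding into a contiguous `(k_blocks, m)` view before calling the kernel. Dropping it suggests `fp8_block_scale_moe_runner` now consumes the padded layout directly. For reference, a sketch of the layout math the deleted lines implemented (shapes inferred from the removed code, not from kernel documentation):

```python
import torch

def unpad_1x128_scales(x_scale: torch.Tensor, m: int, k: int) -> torch.Tensor:
    # Reproduces the deleted reshaping: the flat scale buffer holds one
    # FP32 scale per 1x128 block of the [m, k] activation, with the
    # token dim m padded to a multiple of 4.
    m_4_align = (m + 3) // 4 * 4        # m rounded up to a multiple of 4
    k_blocks = (k + 127) // 128         # number of 128-wide blocks per row
    x_scale = x_scale[:k_blocks * m_4_align]
    x_scale = x_scale.view(k_blocks, m_4_align)
    return x_scale[:, :m].contiguous()  # drop the row padding

# Example: 10 tokens, hidden size 256 -> 2 blocks per token,
# scale buffer laid out as if m were 12.
scales = torch.arange(2 * 12, dtype=torch.float32)
print(unpad_1x128_scales(scales, m=10, k=256).shape)  # torch.Size([2, 10])
```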