@@ -99,14 +99,14 @@ def __init__(
         n_group: int,
         topk_group: int,
         routed_scaling_factor: float,
-        is_thop: bool = True,
+        is_fused: bool = True,
     ):
         super().__init__()
         self.top_k = top_k
         self.topk_group = topk_group
         self.n_group = n_group
         self.routed_scaling_factor = routed_scaling_factor
-        self.is_thop = is_thop
+        self.is_fused = is_fused

     def noaux_tc(self, logits, e_score_correction_bias):
         n_group = self.n_group
@@ -121,7 +121,7 @@ def noaux_tc(self, logits, e_score_correction_bias):
                     "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
                 )

-        if self.is_thop == False:
+        if not self.is_fused:
             group_scores = torch.sum(torch.topk(
                 scores_with_bias.view(scores_shape[:-1] +
                                       [n_group, scores_shape[-1] // n_group]),
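For orientation, the branch above (taken when is_fused is False) computes DeepSeek-V3's group-limited routing in plain PyTorch. Below is a minimal sketch of that idea, with assumed tensor names and top-2-per-group scoring; it is an illustration, not the exact TensorRT-LLM implementation. The fused path (is_fused=True) presumably offloads the same selection to a single custom op.

import torch

def grouped_topk_reference(scores_with_bias: torch.Tensor, n_group: int,
                           topk_group: int, top_k: int):
    # Hypothetical reference of the unfused path: score each expert group by the
    # sum of its two best experts, keep the best `topk_group` groups, then run a
    # plain top-k over the surviving experts.
    num_tokens, num_experts = scores_with_bias.shape
    grouped = scores_with_bias.view(num_tokens, n_group, num_experts // n_group)
    group_scores = torch.topk(grouped, k=2, dim=-1).values.sum(dim=-1)
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(-1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(
        num_tokens, num_experts)
    masked_scores = scores_with_bias.masked_fill(expert_mask == 0, float("-inf"))
    return torch.topk(masked_scores, k=top_k, dim=-1)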
@@ -171,7 +171,7 @@ def apply(
         return topk_indices.to(torch.int32), topk_values.to(torch.float32)


-class Deepseekv3Gate(BaseMoeRoutingMethod):
+class DeepseekV3Gate(BaseMoeRoutingMethod):

     def __init__(
         self,
@@ -182,7 +182,8 @@ def __init__(
         topk_group: int,
         routed_scaling_factor: float,
         dtype: Optional[torch.dtype] = None,
-        is_thop: bool = True,
+        fuse_routing_kernel: bool = True,
+        apply_routing: bool = False,
     ):
         super().__init__()
         self.weight = nn.Parameter(torch.empty((num_experts, hidden_size),
@@ -192,18 +193,20 @@ def __init__(
             (num_experts), dtype=torch.float32),
                                                     requires_grad=False)

-        # TODO: e_score_correction_bias makes sense to live in this gate class, but it is needed for the routing impl
-        # So we don't run into issues with weight loading, we make this gate object the BaseMoeRoutingMethod
-        # and then dispatch to the routing impl for the actual implementation.
-        # This is a bit of a hack and we should clean this up in the future.
+        assert not apply_routing, "DeepseekV3Gate routing is called inside MoE"
+
+        # TODO: e_score_correction_bias belongs in this gate class but is required by the routing impl.
+        # To avoid weight-loading issues, we treat this gate as the BaseMoeRoutingMethod and dispatch to the routing impl.
+        # This is a temporary hack that should be refactored later.
         self.routing_impl = Deepseekv3RoutingImpl(
             top_k=top_k,
             n_group=n_group,
             topk_group=topk_group,
             routed_scaling_factor=routed_scaling_factor,
-            is_thop=is_thop)
+            is_fused=fuse_routing_kernel)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router gemm
         logits = torch.ops.trtllm.cublas_mm(hidden_states,
                                             self.weight.t(),
                                             bias=None,
@@ -219,6 +222,7 @@ def load_weights(self, weights: List[Dict]):
             weights[0]["e_score_correction_bias"][:].to(torch.float32))

     def apply(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        # topk routing
         return self.routing_impl.apply(logits, self.e_score_correction_bias)

     @property
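The two added comments mark the gate's two stages. As a usage sketch only, assuming a fully constructed and weight-loaded DeepseekV3Gate instance passed in by the caller:

import torch

def route_tokens(gate, hidden_states: torch.Tensor):
    # Stage 1: router gemm (cublas_mm against gate.weight) produces per-expert logits.
    router_logits = gate(hidden_states)
    # Stage 2: topk routing, dispatched to Deepseekv3RoutingImpl together with the bias term.
    topk_indices, topk_values = gate.apply(router_logits)
    return topk_indices, topk_values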
@@ -247,22 +251,24 @@ def __init__(self,
         config = model_config.pretrained_config
         self.top_k = top_k
         self.use_dp = model_config.mapping.enable_attention_dp
-        self.gate = Deepseekv3Gate(
+        self.gate = DeepseekV3Gate(
             hidden_size,
             num_experts,
             top_k=top_k,
             n_group=config.n_group,
             topk_group=config.topk_group,
             routed_scaling_factor=config.routed_scaling_factor,
-            dtype=dtype)
+            dtype=dtype,
+            fuse_routing_kernel=True,
+            apply_routing=False)
         self.experts = FusedMoE(
             num_experts=num_experts,
             routing_method=self.gate.routing_method,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             dtype=dtype,
             reduce_results=
-            False,  # In both low latency and attention dp scenarios, FusedMoE needs not to do allreduce inside op.
+            False,  # In both low-latency and attention-DP modes, FusedMoE skips the in-op all-reduce.
             model_config=model_config,
             aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap])

@@ -282,6 +288,7 @@ def __init__(self,
         # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce.
         if shared_tp_size != model_config.mapping.tp_size:
             self.shared_output_scale = shared_tp_size / model_config.mapping.tp_size
+
         self.shared_experts = GatedMLP(
             hidden_size=hidden_size,
             intermediate_size=shared_expert_intermediate_size,
@@ -301,36 +308,34 @@ def __init__(self,

     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, min_latency_mode):
+        # max-throughput
         if self.use_dp and self.mapping.tp_size > 1:
             max_num_token = max(all_rank_num_tokens)
             hidden_states = torch.nn.functional.pad(
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
+            # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
+            # to reduce allreduce BW
             if disable_fp4_allgather():
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           gather_dim=0)
+
         router_logits = self.gate(hidden_states)

-        if hidden_states_fp4 is not None:
-            routed_output = self.experts(hidden_states_fp4,
-                                         router_logits,
-                                         min_latency_mode,
-                                         output_dtype=hidden_states.dtype)
-        else:
-            routed_output = self.experts(
-                hidden_states,
-                router_logits,
-                min_latency_mode,
-                all_rank_num_tokens=all_rank_num_tokens)
+        routed_output = self.experts(hidden_states_fp4 or hidden_states,
+                                     router_logits,
+                                     min_latency_mode,
+                                     output_dtype=hidden_states.dtype,
+                                     all_rank_num_tokens=all_rank_num_tokens)

         return routed_output

     def forward(
         self,
         hidden_states: torch.Tensor,
         hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
-        all_rank_num_tokens=None,
+        all_rank_num_tokens: Optional[list[int]] = None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
         min_latency_mode: Optional[bool] = False,
     ) -> torch.Tensor:
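Two details of the rewritten compute_routed_output are worth noting. First, under attention DP each rank pads its activations to the largest per-rank token count before the optional allgather; a standalone sketch of that padding step (names assumed):

import torch
import torch.nn.functional as F

def pad_to_max_rank_tokens(hidden_states: torch.Tensor,
                           all_rank_num_tokens: list[int]) -> torch.Tensor:
    # Pad the token dimension so every rank contributes the same number of rows
    # to the subsequent allgather.
    max_num_token = max(all_rank_num_tokens)
    return F.pad(hidden_states, (0, 0, 0, max_num_token - hidden_states.shape[0]))

Second, the single experts call relies on hidden_states_fp4 being either None or a truthy Fp4QuantizedTensor wrapper, so hidden_states_fp4 or hidden_states picks the FP4 input whenever it exists (this assumes the wrapper does not define a falsy __bool__).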
@@ -357,15 +362,16 @@ def _compute_routed_output():

         if min_latency_mode:
             return [shared_output, *routed_output]
+        else:
+            assert shared_output.size() == routed_output.size(
+            ), f'unmatched tensor shape'
+            final_hidden_states = shared_output + routed_output
+            if not self.use_dp and self.mapping.tp_size > 1:
+                final_hidden_states = self.all_reduce(
+                    final_hidden_states,
+                    all_reduce_params=final_all_reduce_params)

-        assert shared_output.size() == routed_output.size(
-        ), f'unmatched tensor shape'
-        final_hidden_states = shared_output + routed_output
-        if not self.use_dp and self.mapping.tp_size > 1:
-            final_hidden_states = self.all_reduce(
-                final_hidden_states, all_reduce_params=final_all_reduce_params)
-
-        return final_hidden_states
+        return final_hidden_states


 class DeepseekV3DecoderLayer(DecoderLayer):
@@ -381,31 +387,35 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok

+        self.mapping = model_config.mapping
+        mapping = self.mapping
+
         self.self_attn = DeepseekV3Attention(
             model_config,
             layer_idx=layer_idx,
             aux_stream=aux_stream_dict[AuxStreamType.Attention])
         self.fusion_config = EagerFusionConfig()
-        self.enable_attention_dp = model_config.mapping.enable_attention_dp
-        self.mlp_tp_size = model_config.mapping.tp_size
+        self.enable_attention_dp = mapping.enable_attention_dp
+        self.mlp_tp_size = mapping.tp_size

-        self.enable_fusion = os.environ.get(
-            "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0"
-
-        pp_layer_offset = model_config.mapping.pp_layers(
-            config.num_hidden_layers)[0]
+        pp_layer_offset = mapping.pp_layers(config.num_hidden_layers)[0]
         global_layer_idx = pp_layer_offset + layer_idx

+        enable_fusion = os.environ.get("TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED",
+                                       "0") == "0"
+        self.enable_fusion = enable_fusion and not self.enable_attention_dp
+
         self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4()
+        has_tp = mapping.has_tp()
+        has_pp = mapping.has_pp()

         if (config.n_routed_experts is not None
                 and global_layer_idx >= config.first_k_dense_replace
                 and global_layer_idx % config.moe_layer_freq == 0):
-            self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp
-            self.fusion_config.POST_MOE_FUSION = self.enable_fusion and model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and not model_config.mapping.has_pp(
-            )
+
+            self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
+            self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
+
             self.mlp = Deepseekv3MoE(
                 num_experts=self.num_experts,
                 top_k=self.top_k,
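Restated as plain booleans, the refactored fusion-flag logic above is equivalent to the following sketch (function and argument names are illustrative only):

def moe_fusion_flags(enable_fusion: bool, attention_dp: bool, has_tp: bool,
                     has_pp: bool) -> tuple[bool, bool]:
    # Hedged restatement of the refactor: eager fusion is now disabled globally
    # under attention DP, PRE_MOE_FUSION additionally needs TP, and
    # POST_MOE_FUSION further requires no pipeline parallelism.
    enable_fusion = enable_fusion and not attention_dp
    pre_moe_fusion = enable_fusion and has_tp
    post_moe_fusion = pre_moe_fusion and not has_pp
    return pre_moe_fusion, post_moe_fusion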
@@ -429,15 +439,14 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
             self.mlp_tp_size = math.gcd(
                 math.gcd(
                     config.intermediate_size // 128,
-                    model_config.mapping.tp_size,
+                    mapping.tp_size,
                 ),
-                model_config.mapping.
-                gpus_per_node,  # Avoid costly inter-node TP
+                mapping.gpus_per_node,  # Avoid costly inter-node TP
             )
-            self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and model_config.mapping.has_tp(
-            ) and self.is_nvfp4 and not self.enable_attention_dp
-            self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not self.enable_attention_dp and not model_config.mapping.has_pp(
-            )
+
+            self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_tp and self.is_nvfp4
+            self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not has_pp
+
             self.mlp = GatedMLP(hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
                                 bias=False,
@@ -450,17 +459,21 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype)

+        self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
+                                       or self.fusion_config.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        self.mapping = model_config.mapping
         self.layer_idx = layer_idx
         self.all_reduce = AllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None

         self.deepseek_allreduce_disabled = os.environ.get(
             "TRTLLM_DEEPSEEK_ALLREDUCE_FUSION_DISABLED", "0") == "1"
-        if model_config.mapping.is_multi_node():
+        if mapping.is_multi_node():
             self.deepseek_allreduce_disabled = True

         if not self.deepseek_allreduce_disabled:
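The new disable_attn_allreduce field precomputes, once at init time, the condition that forward previously rebuilt inline. A minimal boolean sketch of that condition (argument names are illustrative):

def attention_allreduce_enabled(pre_moe_fusion: bool, pre_mlp_fusion: bool,
                                tp_size: int, enable_attention_dp: bool) -> bool:
    # The attention all-reduce runs only when no later fusion will cover it,
    # TP is actually in use, and attention data parallelism is off.
    disable_attn_allreduce = (pre_moe_fusion or pre_mlp_fusion or tp_size == 1
                              or enable_attention_dp)
    return not disable_attn_allreduce

Both attention call sites below now pass AllReduceParams(enable_allreduce=not self.disable_attn_allreduce) instead of rebuilding the expression inline.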
@@ -474,15 +487,6 @@ def forward(
         residual: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
-
-        # deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
-        using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
-            0) > 128
-
-        min_latency_mode = True if hidden_states.size(
-            0
-        ) <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 else False
-
         if residual is None:
             residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -492,13 +496,19 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION
-                or self.fusion_config.PRE_MLP_FUSION
-                or self.mapping.tp_size == 1 or self.enable_attention_dp)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )

+        # deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
+        using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
+            0) > 128
+
+        min_latency_mode = True if hidden_states.size(
+            0
+        ) <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 else False
+
         if self.fusion_config.PRE_MOE_FUSION:
             # Custom AR Fusion for DeepseekV3
             if using_prev_fusion:
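The two thresholds moved here can be read as a small decision function over the current token count (a sketch; the 128-token cutoff and the fusion/NVFP4 conditions are taken directly from the diff):

def choose_allreduce_path(num_tokens: int, deepseek_allreduce_disabled: bool,
                          post_moe_fusion: bool, is_nvfp4: bool):
    # The DeepSeek allreduce fusion kernel is only used for small batches (m <= 128);
    # the two-shot range (128~512) is waived due to an accuracy bug, per the comment above.
    using_prev_fusion = deepseek_allreduce_disabled or num_tokens > 128
    min_latency_mode = (num_tokens <= 128 and post_moe_fusion and is_nvfp4)
    return using_prev_fusion, min_latency_mode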
@@ -710,9 +720,8 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1
-                or self.enable_attention_dp)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )

@@ -858,6 +867,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         if model_config.spec_config is not None:
             assert not model_config.mapping.has_pp(
             ), "PP + MTP combination is not supported"
+
             model_nextn = model_config.spec_config.num_nextn_predict_layers
             ckpt_nextn = self.config.num_nextn_predict_layers
             self.num_hidden_layers = self.config.num_hidden_layers