
Commit 593cb9f

Clean up modeling_deepseek.py
Signed-off-by: Hao Lu <[email protected]>
1 parent 26ebd95 commit 593cb9f

File tree

3 files changed: +104 -88 lines


tensorrt_llm/_torch/models/modeling_deepseekv3.py
+80 -70
@@ -99,14 +99,14 @@ def __init__(
         n_group: int,
         topk_group: int,
         routed_scaling_factor: float,
-        is_thop: bool = True,
+        is_fused: bool = True,
     ):
         super().__init__()
         self.top_k = top_k
         self.topk_group = topk_group
         self.n_group = n_group
         self.routed_scaling_factor = routed_scaling_factor
-        self.is_thop = is_thop
+        self.is_fused = is_fused

     def noaux_tc(self, logits, e_score_correction_bias):
         n_group = self.n_group
@@ -121,7 +121,7 @@ def noaux_tc(self, logits, e_score_correction_bias):
                 "Detected NAN in the tensor scores_with_bias. Please check if it matches the expectation."
             )

-        if self.is_thop == False:
+        if not self.is_fused:
            group_scores = torch.sum(torch.topk(
                scores_with_bias.view(scores_shape[:-1] +
                                      [n_group, scores_shape[-1] // n_group]),
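For context, the unfused branch touched above computes DeepSeek-V3's grouped no-aux-loss routing in plain PyTorch. A rough standalone sketch of that computation follows; the function name, the default constants, and the sigmoid scoring are illustrative assumptions, not code from this commit.

import torch

def grouped_topk_routing(logits, e_score_correction_bias, n_group=8,
                         topk_group=4, top_k=8, routed_scaling_factor=2.5):
    # The bias steers which experts are selected, but the final weights come
    # from the unbiased scores.
    n_tokens, n_experts = logits.shape
    scores = torch.sigmoid(logits)
    scores_with_bias = scores + e_score_correction_bias
    # Score each group of experts by the sum of its two best biased scores.
    grouped = scores_with_bias.view(n_tokens, n_group, n_experts // n_group)
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    # Keep only the best topk_group groups per token.
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(-1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(n_tokens, n_experts)
    masked = scores_with_bias.masked_fill(expert_mask == 0, float("-inf"))
    # Final expert selection plus renormalized, scaled weights.
    topk_idx = masked.topk(top_k, dim=-1).indices
    topk_weights = scores.gather(-1, topk_idx)
    topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True) * routed_scaling_factor
    return topk_idx.to(torch.int32), topk_weights.to(torch.float32)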
@@ -171,7 +171,7 @@ def apply(
         return topk_indices.to(torch.int32), topk_values.to(torch.float32)


-class Deepseekv3Gate(BaseMoeRoutingMethod):
+class DeepseekV3Gate(BaseMoeRoutingMethod):

     def __init__(
         self,
@@ -182,7 +182,8 @@ def __init__(
         topk_group: int,
         routed_scaling_factor: float,
         dtype: Optional[torch.dtype] = None,
-        is_thop: bool = True,
+        fuse_routing_kernel: bool = True,
+        apply_routing: bool = False,
     ):
         super().__init__()
         self.weight = nn.Parameter(torch.empty((num_experts, hidden_size),
@@ -192,18 +193,20 @@ def __init__(
                                            (num_experts), dtype=torch.float32),
                                            requires_grad=False)

-        # TODO: e_score_correction_bias makes sense to live in this gate class, but it is needed for the routing impl
-        # So we don't run into issues with weight loading, we make this gate object the BaseMoeRoutingMethod
-        # and then dispatch to the routing impl for the actual implementation.
-        # This is a bit of a hack and we should clean this up in the future.
+        assert not apply_routing, "DeepseekV3Gate routing is called inside MoE"
+
+        # TODO: e_score_correction_bias belongs in this gate class but is required by the routing impl.
+        # To avoid weight-loading issues, we treat this gate as the BaseMoeRoutingMethod and dispatch to the routing impl.
+        # This is a temporary hack that should be refactored later.
         self.routing_impl = Deepseekv3RoutingImpl(
             top_k=top_k,
             n_group=n_group,
             topk_group=topk_group,
             routed_scaling_factor=routed_scaling_factor,
-            is_thop=is_thop)
+            is_fused=fuse_routing_kernel)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # router gemm
         logits = torch.ops.trtllm.cublas_mm(hidden_states,
                                             self.weight.t(),
                                             bias=None,
@@ -219,6 +222,7 @@ def load_weights(self, weights: List[Dict]):
             weights[0]["e_score_correction_bias"][:].to(torch.float32))

     def apply(self, logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        # topk routing
         return self.routing_impl.apply(logits, self.e_score_correction_bias)

     @property
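The TODO above describes a deliberate split: the gate owns the router weight and e_score_correction_bias (so checkpoint loading finds them), forward() runs only the router GEMM, and apply() delegates top-k selection to the routing impl. A toy sketch of that pattern, with hypothetical class names and none of the TRT-LLM specifics, looks like this:

import torch
from torch import nn

class ToyRoutingImpl:
    # Stand-in for Deepseekv3RoutingImpl: owns no weights, only the algorithm.
    def __init__(self, top_k):
        self.top_k = top_k

    def apply(self, logits, bias):
        biased = torch.sigmoid(logits) + bias
        return biased.topk(self.top_k, dim=-1).indices

class ToyGate(nn.Module):
    # Stand-in for the gate: owns the router weight and bias so checkpoint
    # loading sees them, but dispatches the routing decision to the impl.
    def __init__(self, hidden_size, num_experts, top_k=2):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_experts, hidden_size),
                                   requires_grad=False)
        self.e_score_correction_bias = nn.Parameter(torch.zeros(num_experts),
                                                    requires_grad=False)
        self.routing_impl = ToyRoutingImpl(top_k)

    def forward(self, hidden_states):
        # router GEMM only; no routing decision happens here
        return hidden_states @ self.weight.t()

    def apply(self, logits):
        # top-k routing, dispatched to the impl
        return self.routing_impl.apply(logits, self.e_score_correction_bias)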
@@ -247,22 +251,24 @@ def __init__(self,
         config = model_config.pretrained_config
         self.top_k = top_k
         self.use_dp = model_config.mapping.enable_attention_dp
-        self.gate = Deepseekv3Gate(
+        self.gate = DeepseekV3Gate(
             hidden_size,
             num_experts,
             top_k=top_k,
             n_group=config.n_group,
             topk_group=config.topk_group,
             routed_scaling_factor=config.routed_scaling_factor,
-            dtype=dtype)
+            dtype=dtype,
+            fuse_routing_kernel=True,
+            apply_routing=False)
         self.experts = FusedMoE(
             num_experts=num_experts,
             routing_method=self.gate.routing_method,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             dtype=dtype,
             reduce_results=
-            False,  # In both low latency and attention dp scenarios, FusedMoE needs not to do allreduce inside op.
+            False,  # In both low-latency and attention-DP modes, FusedMoE skips the in-op all-reduce.
             model_config=model_config,
             aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap])

@@ -282,6 +288,7 @@ def __init__(self,
         # If shared_tp_size has been overridden, the output of shared experts needs to be scaled down accordingly before all-reduce.
         if shared_tp_size != model_config.mapping.tp_size:
             self.shared_output_scale = shared_tp_size / model_config.mapping.tp_size
+
         self.shared_experts = GatedMLP(
             hidden_size=hidden_size,
             intermediate_size=shared_expert_intermediate_size,
@@ -301,36 +308,34 @@ def __init__(self,

     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, min_latency_mode):
+        # max-throughput
         if self.use_dp and self.mapping.tp_size > 1:
             max_num_token = max(all_rank_num_tokens)
             hidden_states = torch.nn.functional.pad(
                 hidden_states,
                 (0, 0, 0, max_num_token - hidden_states.shape[0]))
+            # FP4 all_gather moves this bf16 allgather to after topk and fp4 quantization
+            # to reduce allreduce BW
             if disable_fp4_allgather():
                 hidden_states = allgather(hidden_states,
                                           self.mapping,
                                           gather_dim=0)
+
         router_logits = self.gate(hidden_states)

-        if hidden_states_fp4 is not None:
-            routed_output = self.experts(hidden_states_fp4,
-                                         router_logits,
-                                         min_latency_mode,
-                                         output_dtype=hidden_states.dtype)
-        else:
-            routed_output = self.experts(
-                hidden_states,
-                router_logits,
-                min_latency_mode,
-                all_rank_num_tokens=all_rank_num_tokens)
+        routed_output = self.experts(hidden_states_fp4 or hidden_states,
+                                     router_logits,
+                                     min_latency_mode,
+                                     output_dtype=hidden_states.dtype,
+                                     all_rank_num_tokens=all_rank_num_tokens)

         return routed_output

     def forward(
         self,
         hidden_states: torch.Tensor,
         hidden_states_fp4: Optional[Fp4QuantizedTensor] = None,
-        all_rank_num_tokens=None,
+        all_rank_num_tokens: Optional[list[int]] = None,
         final_all_reduce_params: Optional[AllReduceParams] = None,
         min_latency_mode: Optional[bool] = False,
     ) -> torch.Tensor:
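When attention data parallelism is enabled, each rank may hold a different number of tokens, so the hunk above pads every rank to the largest token count before gathering. A rough standalone equivalent using plain torch.distributed (the real code goes through TRT-LLM's allgather op; the helper name is illustrative and an initialized process group is assumed):

import torch
import torch.distributed as dist

def pad_and_allgather(hidden_states, all_rank_num_tokens):
    # Pad the token dimension so every rank contributes a same-shaped tensor.
    max_num_tokens = max(all_rank_num_tokens)
    padded = torch.nn.functional.pad(
        hidden_states, (0, 0, 0, max_num_tokens - hidden_states.shape[0]))
    # Gather the padded activations from all ranks and concatenate on tokens.
    gathered = [torch.empty_like(padded) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, padded)
    return torch.cat(gathered, dim=0)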
@@ -357,15 +362,16 @@ def _compute_routed_output():

         if min_latency_mode:
             return [shared_output, *routed_output]
+        else:
+            assert shared_output.size() == routed_output.size(
+            ), f'unmatched tensor shape'
+            final_hidden_states = shared_output + routed_output
+            if not self.use_dp and self.mapping.tp_size > 1:
+                final_hidden_states = self.all_reduce(
+                    final_hidden_states,
+                    all_reduce_params=final_all_reduce_params)

-        assert shared_output.size() == routed_output.size(
-        ), f'unmatched tensor shape'
-        final_hidden_states = shared_output + routed_output
-        if not self.use_dp and self.mapping.tp_size > 1:
-            final_hidden_states = self.all_reduce(
-                final_hidden_states, all_reduce_params=final_all_reduce_params)
-
-        return final_hidden_states
+            return final_hidden_states


 class DeepseekV3DecoderLayer(DecoderLayer):
@@ -381,31 +387,35 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.num_shared_experts = config.n_shared_experts
         self.top_k = config.num_experts_per_tok

+        self.mapping = model_config.mapping
+        mapping = self.mapping
+
         self.self_attn = DeepseekV3Attention(
             model_config,
             layer_idx=layer_idx,
             aux_stream=aux_stream_dict[AuxStreamType.Attention])
         self.fusion_config = EagerFusionConfig()
-        self.enable_attention_dp = model_config.mapping.enable_attention_dp
-        self.mlp_tp_size = model_config.mapping.tp_size
+        self.enable_attention_dp = mapping.enable_attention_dp
+        self.mlp_tp_size = mapping.tp_size

-        self.enable_fusion = os.environ.get(
-            "TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED", "0") == "0"
-
-        pp_layer_offset = model_config.mapping.pp_layers(
-            config.num_hidden_layers)[0]
+        pp_layer_offset = mapping.pp_layers(config.num_hidden_layers)[0]
         global_layer_idx = pp_layer_offset + layer_idx

+        enable_fusion = os.environ.get("TRTLLM_DEEPSEEK_EAGER_FUSION_DISABLED",
+                                       "0") == "0"
+        self.enable_fusion = enable_fusion and not self.enable_attention_dp
+
         self.is_nvfp4 = model_config.quant_config.layer_quant_mode.has_nvfp4()
+        has_tp = mapping.has_tp()
+        has_pp = mapping.has_pp()

         if (config.n_routed_experts is not None
                 and global_layer_idx >= config.first_k_dense_replace
                 and global_layer_idx % config.moe_layer_freq == 0):
-            self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp
-            self.fusion_config.POST_MOE_FUSION = self.enable_fusion and model_config.mapping.has_tp(
-            ) and not self.enable_attention_dp and not model_config.mapping.has_pp(
-            )
+
+            self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
+            self.fusion_config.POST_MOE_FUSION = self.fusion_config.PRE_MOE_FUSION and not has_pp
+
             self.mlp = Deepseekv3MoE(
                 num_experts=self.num_experts,
                 top_k=self.top_k,
@@ -429,15 +439,14 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
             self.mlp_tp_size = math.gcd(
                 math.gcd(
                     config.intermediate_size // 128,
-                    model_config.mapping.tp_size,
+                    mapping.tp_size,
                 ),
-                model_config.mapping.
-                gpus_per_node,  # Avoid costly inter-node TP
+                mapping.gpus_per_node,  # Avoid costly inter-node TP
             )
-            self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and model_config.mapping.has_tp(
-            ) and self.is_nvfp4 and not self.enable_attention_dp
-            self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not self.enable_attention_dp and not model_config.mapping.has_pp(
-            )
+
+            self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_tp and self.is_nvfp4
+            self.fusion_config.POST_MLP_FUSION = self.enable_fusion and self.mlp_tp_size > 1 and not has_pp
+
             self.mlp = GatedMLP(hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
                                 bias=False,
@@ -450,17 +459,21 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                                        eps=config.rms_norm_eps,
                                        dtype=config.torch_dtype)

+        self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
+                                       or self.fusion_config.PRE_MLP_FUSION
+                                       or self.mapping.tp_size == 1
+                                       or self.enable_attention_dp)
+
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
-        self.mapping = model_config.mapping
         self.layer_idx = layer_idx
         self.all_reduce = AllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None

         self.deepseek_allreduce_disabled = os.environ.get(
             "TRTLLM_DEEPSEEK_ALLREDUCE_FUSION_DISABLED", "0") == "1"
-        if model_config.mapping.is_multi_node():
+        if mapping.is_multi_node():
             self.deepseek_allreduce_disabled = True

         if not self.deepseek_allreduce_disabled:
@@ -474,15 +487,6 @@ def forward(
         residual: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
-
-        # deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
-        using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
-            0) > 128
-
-        min_latency_mode = True if hidden_states.size(
-            0
-        ) <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 else False
-
         if residual is None:
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
@@ -492,13 +496,19 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION
-                or self.fusion_config.PRE_MLP_FUSION
-                or self.mapping.tp_size == 1 or self.enable_attention_dp)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )

+        # deepseek allreduce kernel is better when m < 512, two shot(128~512) has acc bug, waive
+        using_prev_fusion = self.deepseek_allreduce_disabled or hidden_states.size(
+            0) > 128
+
+        min_latency_mode = True if hidden_states.size(
+            0
+        ) <= 128 and self.fusion_config.POST_MOE_FUSION and self.is_nvfp4 else False
+
         if self.fusion_config.PRE_MOE_FUSION:
             # Custom AR Fusion for DeepseekV3
             if using_prev_fusion:
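The two size-based flags computed above gate which all-reduce path the layer takes. Restated as a standalone helper (the function name is made up; the 128-token cut-off and the note about the 128-512 two-shot accuracy bug come from the comment in the diff):

def pick_allreduce_strategy(num_tokens, fusion_kernel_disabled,
                            post_moe_fusion, is_nvfp4):
    # Fall back to the previous fusion path for larger batches: the fused
    # DeepSeek all-reduce kernel is preferred only when m < 512, and the
    # two-shot variant (128-512 tokens) is waived due to an accuracy bug.
    using_prev_fusion = fusion_kernel_disabled or num_tokens > 128
    # Min-latency MoE path only for small NVFP4 batches with post-MoE fusion.
    min_latency_mode = num_tokens <= 128 and post_moe_fusion and is_nvfp4
    return using_prev_fusion, min_latency_mode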
@@ -710,9 +720,8 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(enable_allreduce=not (
-                self.fusion_config.PRE_MOE_FUSION or self.mapping.tp_size == 1
-                or self.enable_attention_dp)),
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_attn_allreduce),
             **kwargs,
         )

@@ -858,6 +867,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         if model_config.spec_config is not None:
             assert not model_config.mapping.has_pp(
             ), "PP + MTP combination is not supported"
+
             model_nextn = model_config.spec_config.num_nextn_predict_layers
             ckpt_nextn = self.config.num_nextn_predict_layers
             self.num_hidden_layers = self.config.num_hidden_layers

tensorrt_llm/_torch/pipeline_interface.py
+1 -1
@@ -16,7 +16,7 @@ class PipelineInterface:
     - Slicing: pp[start:end]

     Note: When using this interface in pp, the packing/unpacking and send/recv
-    operations must be used symmetrically within stage and between succsive ranks.
+    operations must be used symmetrically within stage and between successive ranks.
     """
     _pp_comm = None
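The corrected docstring line describes a protocol requirement: every send of a packed tensor from one pipeline stage must be matched by a recv of the same shape and dtype on the next rank. A minimal illustration with plain torch.distributed (helper names are illustrative; an initialized process group is assumed):

import torch
import torch.distributed as dist

def send_to_next_stage(tensor, rank):
    # Sender side: pack/flatten, then send to the successive rank.
    dist.send(tensor.contiguous(), dst=rank + 1)

def recv_from_prev_stage(shape, dtype, rank):
    # Receiver side: post the matching recv with the same shape/dtype, then
    # unpack. Asymmetric send/recv pairs between stages would deadlock.
    buf = torch.empty(shape, dtype=dtype)
    dist.recv(buf, src=rank - 1)
    return buf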
