
Commit 2cbbdc9

[Deepseek] Redesign multi-stream API
Signed-off-by: Hao Lu <[email protected]>
1 parent c51e90d commit 2cbbdc9

File tree

3 files changed (+125, -73 lines)

  tensorrt_llm/_torch/models/modeling_deepseekv3.py
  tensorrt_llm/_torch/modules/attention.py
  tensorrt_llm/_torch/modules/multi_stream_utils.py

tensorrt_llm/_torch/models/modeling_deepseekv3.py  (+16 -18)
@@ -23,10 +23,10 @@
 from ..modules.fused_moe import BaseMoeRoutingMethod, FusedMoE
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear
+from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from ..pipeline_interface import PipelineInterface
-from ..pyexecutor.cuda_graph_runner import is_graph_capturing
 from ..speculative import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from ..utils import (AuxStreamType, EventType, Fp4QuantizedTensor,
                      disable_fp4_allgather)
@@ -351,27 +351,25 @@ def forward(
     ) -> torch.Tensor:
         if min_latency_mode:
             assert not self.use_dp
-        # Only enable multi-stream for cuda graph since switch stream has extra host overhead
-        # This design is mainly for low latency use case. Need to improve for max throughput use case.
-        do_multi_stream = is_graph_capturing()
-        if do_multi_stream:
-            self.event_dict[EventType.Main].record()
-        shared_output = self.shared_experts(hidden_states)
-        if self.shared_output_scale is not None:
-            shared_output *= self.shared_output_scale
-        if do_multi_stream:
-            with torch.cuda.stream(self.aux_stream):
-                self.event_dict[EventType.Main].wait()
-                routed_output = self.compute_routed_output(
-                    hidden_states, hidden_states_fp4, all_rank_num_tokens,
-                    min_latency_mode)
-                self.event_dict[EventType.MoeShared].record()
-            self.event_dict[EventType.MoeShared].wait()
-        else:
+
+        def _compute_shared_output():
+            shared_output = self.shared_experts(hidden_states)
+            if self.shared_output_scale is not None:
+                shared_output *= self.shared_output_scale
+            return shared_output
+
+        def _compute_routed_output():
             routed_output = self.compute_routed_output(hidden_states,
                                                        hidden_states_fp4,
                                                        all_rank_num_tokens,
                                                        min_latency_mode)
+            return routed_output
+
+        shared_output, routed_output = maybe_execute_in_parallel(
+            _compute_shared_output, _compute_routed_output,
+            self.event_dict[EventType.Main],
+            self.event_dict[EventType.MoeShared], self.aux_stream)
+
         if min_latency_mode:
             return [shared_output, *routed_output]
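To make the refactor easier to follow, the sketch below spells out in plain PyTorch the fork/join event pattern that the old forward() wrote inline and that maybe_execute_in_parallel now hides behind a single call. The helper run_two_branches and its argument names are illustrative only; they are not part of the commit.

import torch

# Illustrative sketch (not part of this commit) of the fork/join pattern that
# maybe_execute_in_parallel encapsulates: record an event on the main stream,
# run one branch there, run the other branch on an auxiliary stream that waits
# on that event, then make the main stream wait for the auxiliary branch.
def run_two_branches(fn_main, fn_aux, aux_stream: torch.cuda.Stream):
    fork_event = torch.cuda.Event()   # main stream reached the fork point
    join_event = torch.cuda.Event()   # auxiliary branch has finished

    fork_event.record()               # recorded on the current (main) stream
    out_main = fn_main()              # e.g. the shared experts

    with torch.cuda.stream(aux_stream):
        fork_event.wait()             # aux stream starts only after the fork point
        out_aux = fn_aux()            # e.g. the routed experts
        join_event.record()
    join_event.wait()                 # main stream waits for the auxiliary branch

    return out_main, out_aux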

tensorrt_llm/_torch/modules/attention.py  (+62 -55)
@@ -15,6 +15,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
+from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
 from .rotary_embedding import RotaryEmbedding
@@ -517,19 +518,14 @@ def forward(
             q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
                 [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
                 -1)
-            do_multi_stream = torch.cuda.is_current_stream_capturing(
-            ) and self.aux_stream is not None
-            if do_multi_stream:
-                self.ln_events[0].record()
-                compressed_kv = self.kv_a_layernorm(compressed_kv)
-                with torch.cuda.stream(self.aux_stream):
-                    self.ln_events[0].wait()
-                    q = self.q_a_layernorm(q)
-                    self.ln_events[1].record()
-                self.ln_events[1].wait()
-            else:
-                q = self.q_a_layernorm(q)
-                compressed_kv = self.kv_a_layernorm(compressed_kv)
+
+            q, compressed_kv = maybe_execute_in_parallel(
+                lambda: self.q_a_layernorm(q),
+                lambda: self.kv_a_layernorm(compressed_kv),
+                self.ln_events[0],
+                self.ln_events[1],
+                self.aux_stream,
+            )

             q = self.q_b_proj(q)
@@ -641,54 +637,65 @@ def forward_generation(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         num_tokens = q.shape[0]
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
-
         q_nope, q_pe = q.view([
             -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim
         ]).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

-        # fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank]
-        # 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside AttentionOp
-        fused_q = torch.empty(
-            [
-                num_tokens, self.num_heads,
-                (self.kv_lora_rank + self.qk_rope_head_dim)
-            ],
-            dtype=q.dtype,
-            device=q.device,
+        def _run_bmm():
+            # fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank]
+            # 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside AttentionOp
+            fused_q = torch.empty(
+                [
+                    num_tokens, self.num_heads,
+                    (self.kv_lora_rank + self.qk_rope_head_dim)
+                ],
+                dtype=q.dtype,
+                device=q.device,
+            )
+            if self.k_b_proj_trans.dtype == torch.bfloat16:
+                # [num_heads, num_tokens, self.qk_nope_head_dim]
+                q_nope_t = q_nope.transpose(0, 1)
+                # [num_heads, num_tokens, self.kv_lora_rank]
+                q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
+
+                # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
+                # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
+                # The output of bmm is written directly into fused_q
+                torch.ops.trtllm.bmm_out(q_nope_t,
+                                         self.k_b_proj_trans.transpose(1, 2),
+                                         q_nope_out)
+            elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
+                q_nope_fp8, q_nope_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
+                    q_nope)
+                # [num_heads, num_tokens, self.kv_lora_rank]
+                q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
+
+                torch.ops.trtllm.fp8_block_scaling_bmm_out(
+                    q_nope_fp8, self.k_b_proj_trans, q_nope_scales,
+                    self.k_b_proj_trans_scale, q_nope_out)
+                q_nope_scales = None
+            else:
+                raise NotImplementedError(
+                    f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
+
+            fused_q = fused_q.view([
+                num_tokens,
+                self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
+            ])
+            return fused_q
+
+        def _concat_kv_cache():
+            latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+            return latent_cache
+
+        fused_q, latent_cache = maybe_execute_in_parallel(
+            _run_bmm,
+            _concat_kv_cache,
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
         )

-        if self.k_b_proj_trans.dtype == torch.bfloat16:
-            # [num_heads, num_tokens, self.qk_nope_head_dim]
-            q_nope = q_nope.transpose(0, 1)
-            # [num_heads, num_tokens, self.kv_lora_rank]
-            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
-
-            # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
-            # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
-            # The output of bmm is written directly into fused_q
-            torch.ops.trtllm.bmm_out(q_nope,
-                                     self.k_b_proj_trans.transpose(1, 2),
-                                     q_nope_out)
-        elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
-            q_nope, q_nope_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
-                q_nope)
-            # [num_heads, num_tokens, self.kv_lora_rank]
-            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
-
-            torch.ops.trtllm.fp8_block_scaling_bmm_out(
-                q_nope, self.k_b_proj_trans, q_nope_scales,
-                self.k_b_proj_trans_scale, q_nope_out)
-            q_nope_scales = None
-        else:
-            raise NotImplementedError(
-                f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
-
-        fused_q = fused_q.view([
-            num_tokens,
-            self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
-        ])
-
         # out_scale = getattr(self.o_proj, "inv_input_scale", None)
         out_scale = None # Although we use FP8 MLA for generation phase, the output is still in BF16
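The _run_bmm closure above relies on an out-variant batched matmul that writes straight into a strided view of the preallocated fused_q buffer. The sketch below reproduces the same shapes with stock PyTorch ops; the sizes are made-up placeholders, and plain torch.bmm plus copy_ stands in for the custom torch.ops.trtllm.bmm_out kernel, which avoids the intermediate tensor by writing into the view directly.

import torch

# Placeholder sizes, chosen only to illustrate the shapes in the diff comments above.
num_tokens, num_heads = 4, 8
qk_nope_head_dim, kv_lora_rank, qk_rope_head_dim = 128, 512, 64

q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
# Stand-in for self.k_b_proj_trans: [num_heads, kv_lora_rank, qk_nope_head_dim]
k_b_proj_trans = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim)

# Preallocate fused_q; rope(q_pe) later fills the trailing qk_rope_head_dim slice.
fused_q = torch.empty(num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim)

# [num_heads, num_tokens, qk_nope_head_dim] x [num_heads, qk_nope_head_dim, kv_lora_rank]
# -> [num_heads, num_tokens, kv_lora_rank]
q_nope_out = torch.bmm(q_nope.transpose(0, 1), k_b_proj_trans.transpose(1, 2))

# Copy the result into the leading kv_lora_rank slice of fused_q; the fused
# trtllm.bmm_out op skips this copy by writing into the strided view in place.
fused_q[..., :kv_lora_rank].copy_(q_nope_out.transpose(0, 1))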

tensorrt_llm/_torch/modules/multi_stream_utils.py  (+47 -0, new file)

@@ -0,0 +1,47 @@
+from typing import Any, Callable, Optional
+
+import torch
+
+from ..pyexecutor.cuda_graph_runner import is_graph_capturing
+
+
+def maybe_execute_in_parallel(
+        fn0: Callable,
+        fn1: Callable,
+        event0: torch.cuda.Event,
+        event1: torch.cuda.Event,
+        aux_stream: Optional[torch.cuda.Stream] = None) -> tuple[Any, Any]:
+    """Utility function to run two functions in two cuda streams in parallel. Multi-stream is
+    only enabled when cuda graph is turned on because switching streams has extra host overhead.
+
+    This design is mainly for the low-latency use case. It needs to be improved for the
+    max-throughput use case.
+    For simplicity, fn0 and fn1 do not support inputs.
+
+    Args:
+        fn0 (Callable): callable for the default stream
+        fn1 (Callable): callable for the second stream, aux_stream
+        event0 (torch.cuda.Event): cuda event for fn0
+        event1 (torch.cuda.Event): cuda event for fn1
+        aux_stream (Optional[torch.cuda.Stream]): the second cuda stream for fn1.
+            Multi-stream is disabled when aux_stream is None.
+
+    Returns:
+        tuple[Any, Any]: the return values of fn0() and fn1()
+    """
+
+    do_multi_stream = is_graph_capturing() and aux_stream is not None
+
+    if do_multi_stream:
+        event0.record()
+        result0 = fn0()
+
+        with torch.cuda.stream(aux_stream):
+            event0.wait()
+            result1 = fn1()
+            event1.record()
+        event1.wait()
+    else:
+        result0 = fn0()
+        result1 = fn1()
+    return (result0, result1)
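A minimal usage sketch of the new helper follows, assuming the import path tensorrt_llm._torch.modules.multi_stream_utils and a CUDA device. The _ParallelBranches module and its layer names are hypothetical; they only illustrate the calling convention used by the two call sites above (one aux stream plus a pair of events owned by the module, and two zero-argument closures).

import torch
import torch.nn as nn

# Assumed import path for the module added by this commit.
from tensorrt_llm._torch.modules.multi_stream_utils import maybe_execute_in_parallel


class _ParallelBranches(nn.Module):
    """Hypothetical module mirroring how the MoE and MLA call sites use the helper."""

    def __init__(self, hidden_size: int, aux_stream: torch.cuda.Stream):
        super().__init__()
        self.branch_a = nn.Linear(hidden_size, hidden_size)
        self.branch_b = nn.Linear(hidden_size, hidden_size)
        self.aux_stream = aux_stream
        # One event per branch, analogous to EventType.Main / EventType.MoeShared.
        self.events = [torch.cuda.Event(), torch.cuda.Event()]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The closures take no arguments and capture x, matching the helper's contract.
        out_a, out_b = maybe_execute_in_parallel(
            lambda: self.branch_a(x),  # default stream
            lambda: self.branch_b(x),  # aux_stream, only while capturing a CUDA graph
            self.events[0],
            self.events[1],
            self.aux_stream,
        )
        return out_a + out_b

Outside of CUDA-graph capture, or when aux_stream is None, the helper simply runs the two closures sequentially on the current stream, so the call site is identical in both modes.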
