 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
+from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
 from .rotary_embedding import RotaryEmbedding
 
@@ -517,19 +518,14 @@ def forward(
         q, compressed_kv, k_pe = self.fused_a(hidden_states).split(
             [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim],
             -1)
-        do_multi_stream = torch.cuda.is_current_stream_capturing(
-        ) and self.aux_stream is not None
-        if do_multi_stream:
-            self.ln_events[0].record()
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
-            with torch.cuda.stream(self.aux_stream):
-                self.ln_events[0].wait()
-                q = self.q_a_layernorm(q)
-                self.ln_events[1].record()
-            self.ln_events[1].wait()
-        else:
-            q = self.q_a_layernorm(q)
-            compressed_kv = self.kv_a_layernorm(compressed_kv)
+
+        q, compressed_kv = maybe_execute_in_parallel(
+            lambda: self.q_a_layernorm(q),
+            lambda: self.kv_a_layernorm(compressed_kv),
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
 
         q = self.q_b_proj(q)
 
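Note: `maybe_execute_in_parallel` lives in `.multi_stream_utils`, which is not shown in this diff. Judging from the inline logic removed above, it presumably runs one callable on the current CUDA stream and the other on the auxiliary stream, synchronized via the two events, when a CUDA graph is being captured and an aux stream is available, and falls back to sequential execution otherwise. Below is a minimal sketch under that assumption; the real helper may differ in details such as its defaults or which callable is dispatched to the auxiliary stream.

```python
import torch
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def maybe_execute_in_parallel(
        fn0: Callable[[], T],
        fn1: Callable[[], U],
        event0: torch.cuda.Event,
        event1: torch.cuda.Event,
        aux_stream: Optional[torch.cuda.Stream] = None) -> Tuple[T, U]:
    """Sketch only: mirrors the inline multi-stream code this PR removes."""
    do_multi_stream = torch.cuda.is_current_stream_capturing(
    ) and aux_stream is not None
    if do_multi_stream:
        event0.record()      # mark the point both branches depend on
        result0 = fn0()      # runs on the current (main) stream
        with torch.cuda.stream(aux_stream):
            event0.wait()    # aux stream waits for the shared inputs
            result1 = fn1()  # runs on the auxiliary stream
            event1.record()
        event1.wait()        # main stream waits for the aux-stream result
    else:
        result0 = fn0()
        result1 = fn1()
    return result0, result1
```

Results are returned in argument order, which is why the call site above can unpack them directly as `q, compressed_kv`.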
@@ -641,54 +637,65 @@ def forward_generation(
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         num_tokens = q.shape[0]
-        latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
-
         q_nope, q_pe = q.view([
             -1, self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim
         ]).split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        # fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank]
-        # 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside AttentionOp
-        fused_q = torch.empty(
-            [
-                num_tokens, self.num_heads,
-                (self.kv_lora_rank + self.qk_rope_head_dim)
-            ],
-            dtype=q.dtype,
-            device=q.device,
+        def _run_bmm():
+            # fused_q contains 1) the result of the following bmm with shape [num_tokens, num_heads, kv_lora_rank]
+            # 2) rope(q_pe) with shape [num_tokens, num_heads, qk_rope_head_dim]. rope is applied inside AttentionOp
+            fused_q = torch.empty(
+                [
+                    num_tokens, self.num_heads,
+                    (self.kv_lora_rank + self.qk_rope_head_dim)
+                ],
+                dtype=q.dtype,
+                device=q.device,
+            )
+            if self.k_b_proj_trans.dtype == torch.bfloat16:
+                # [num_heads, num_tokens, self.qk_nope_head_dim]
+                q_nope_t = q_nope.transpose(0, 1)
+                # [num_heads, num_tokens, self.kv_lora_rank]
+                q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
+
+                # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
+                # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
+                # The output of bmm is written directly into fused_q
+                torch.ops.trtllm.bmm_out(q_nope_t,
+                                         self.k_b_proj_trans.transpose(1, 2),
+                                         q_nope_out)
+            elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
+                q_nope_fp8, q_nope_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
+                    q_nope)
+                # [num_heads, num_tokens, self.kv_lora_rank]
+                q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
+
+                torch.ops.trtllm.fp8_block_scaling_bmm_out(
+                    q_nope_fp8, self.k_b_proj_trans, q_nope_scales,
+                    self.k_b_proj_trans_scale, q_nope_out)
+                q_nope_scales = None
+            else:
+                raise NotImplementedError(
+                    f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
+
+            fused_q = fused_q.view([
+                num_tokens,
+                self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
+            ])
+            return fused_q
+
+        def _concat_kv_cache():
+            latent_cache = torch.concat([compressed_kv, k_pe], dim=-1)
+            return latent_cache
+
+        fused_q, latent_cache = maybe_execute_in_parallel(
+            _run_bmm,
+            _concat_kv_cache,
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
         )
 
-        if self.k_b_proj_trans.dtype == torch.bfloat16:
-            # [num_heads, num_tokens, self.qk_nope_head_dim]
-            q_nope = q_nope.transpose(0, 1)
-            # [num_heads, num_tokens, self.kv_lora_rank]
-            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
-
-            # [num_heads, num_tokens, self.qk_nope_head_dim] x [num_heads, kv_lora_rank, qk_nope_head_dim]
-            # -> [num_heads, num_tokens, kv_lora_rank] -> [num_tokens, num_heads, kv_lora_rank]
-            # The output of bmm is written directly into fused_q
-            torch.ops.trtllm.bmm_out(q_nope,
-                                     self.k_b_proj_trans.transpose(1, 2),
-                                     q_nope_out)
-        elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn:
-            q_nope, q_nope_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
-                q_nope)
-            # [num_heads, num_tokens, self.kv_lora_rank]
-            q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
-
-            torch.ops.trtllm.fp8_block_scaling_bmm_out(
-                q_nope, self.k_b_proj_trans, q_nope_scales,
-                self.k_b_proj_trans_scale, q_nope_out)
-            q_nope_scales = None
-        else:
-            raise NotImplementedError(
-                f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
-
-        fused_q = fused_q.view([
-            num_tokens,
-            self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
-        ])
-
         # out_scale = getattr(self.o_proj, "inv_input_scale", None)
         out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
 
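The same pattern is applied in `forward_generation`: the absorbed `k_b_proj_trans` BMM and the cheap `latent_cache` concatenation are wrapped in closures (`_run_bmm`, `_concat_kv_cache`) and handed to `maybe_execute_in_parallel`, reusing the same `ln_events` pair and `aux_stream` as the layernorm path, so the concat can presumably overlap with the BMM when a CUDA graph is being captured; outside of capture the two closures should simply run back to back, leaving numerical behavior unchanged. Note also that `q_nope` is renamed to `q_nope_t`/`q_nope_fp8` inside `_run_bmm`: rebinding the captured `q_nope` in the nested function would make it a local variable and raise an UnboundLocalError on first read.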