# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Portions of this code are from DeepSeek DeepEP project
# Copyright (c) 2025 DeepSeek
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE

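"""DeepEP-backed fused dispatch/combine operations for MoE token routing on Paddle.

Minimal usage sketch (``ep_group`` and ``run_experts`` are illustrative
placeholders, not names defined in this module):

    recv_x, recv_indices, recv_probs, tokens_per_expert, handle = fused_dispatch(
        hidden_states, topk_indices, topk_probs, num_experts, ep_group
    )
    expert_output = run_experts(recv_x, tokens_per_expert)  # per-expert computation
    output = fused_combine(expert_output, ep_group, handle)
"""
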
import paddle
from paddle.autograd import PyLayer
from paddle.distributed.communication.group import Group

try:
    from paddle.distributed import deep_ep

    HAVE_DEEP_EP = True
except ImportError:
    HAVE_DEEP_EP = False

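# Process-wide cache for the DeepEP communication buffer; it is reused across calls
# and reallocated only when the group or the required sizes change.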
_buffer = None


def get_hidden_bytes(x: paddle.Tensor) -> int:
    """Calculate the number of hidden bytes for a tensor.

    Args:
        x (paddle.Tensor): Input tensor

    Returns:
        int: Number of hidden bytes
    """
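    # NOTE: assume at least 2 bytes per element when sizing buffers, presumably so
    # that low-precision (e.g. FP8) activations still reserve enough space for
    # 16-bit payloads exchanged alongside them.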
    return x.shape[1] * max(x.element_size(), 2)


def get_buffer(group: Group, hidden_bytes: int):
    """Get or create a buffer for all-to-all communication.

    Args:
        group (paddle.distributed.communication.group.Group): Process group for communication
        hidden_bytes (int): Number of hidden bytes needed

    Returns:
        deep_ep.Buffer: Communication buffer
    """
    global _buffer
    # TODO: derive the buffer sizes from DeepEP's dispatch/combine config hints, e.g.
    #   deep_ep.Buffer.get_dispatch_config(group.world_size).get_nvl_buffer_size_hint(...)
    # instead of the fixed 1 GB NVLink buffer used below.
    num_nvl_bytes, num_rdma_bytes = int(1e9), 0

    # Allocate a new buffer if none exists yet, or if the existing one is too small
    # NOTES: the adaptive routing configuration of the network **must be off**
    if (
        _buffer is None
        or _buffer.group != group
        or _buffer.num_nvl_bytes < num_nvl_bytes
        or _buffer.num_rdma_bytes < num_rdma_bytes
    ):
        _buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer


class FusedDispatch(PyLayer):
    """Fused dispatch operation for MoE routing, combining computation and communication."""

    @staticmethod
    def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None):
        """Forward pass of fused dispatch."""
        # Calculate layout before actual dispatch
        buffer = get_buffer(group, get_hidden_bytes(x))
        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = buffer.get_dispatch_layout(
            token_indices,
            num_experts,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

        # Do MoE dispatch
        # NOTES: the CPU will wait for GPU's signal to arrive,
        # so this is not compatible with CUDA graph
        (
            recv_x,
            recv_token_indices,
            recv_token_probs,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        ) = buffer.dispatch(
            x,
            topk_idx=token_indices,
            topk_weights=token_probs.cast(paddle.float32),
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )

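        # Save the communication context needed by the backward pass, and expose the
        # per-expert receive counts as a tensor for downstream per-expert computation.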
        ctx.group = group
        ctx.handle = handle
        ctx.event = event
        tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list)

        return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle)

    @staticmethod
    def backward(
        ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle
    ):
        """Backward pass of fused dispatch."""
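        # The backward of a dispatch is a combine: gradients received on the expert
        # ranks are reduced back to the ranks that originally held the tokens.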
        buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        handle = ctx.handle

        grad_x, grad_token_probs, event = buffer.combine(
            grad_output.contiguous(),
            handle,
            topk_weights=grad_token_probs.cast(paddle.float32),
            previous_event=None,
            async_finish=False,
            allocate_on_comm_stream=False,
        )
        return grad_x, None, grad_token_probs, None, None, None


class FusedCombine(PyLayer):
    """Fused combine operation for MoE output, combining computation and communication."""

    @staticmethod
    def forward(ctx, x, group, handle, previous_event=None):
        """Forward pass of fused combine."""
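        # Reduce expert outputs back to their source ranks, reusing the handle
        # produced by the matching dispatch.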
        buffer = get_buffer(group, get_hidden_bytes(x))
        combined_x, _, event = buffer.combine(
            x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False
        )
        ctx.handle = handle
        ctx.group = group
        ctx.previous_event = previous_event

        return combined_x

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass of fused combine."""
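        # The backward of a combine is a dispatch: the output gradient is scattered
        # back to the expert ranks using the handle saved in the forward pass.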
        buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
        grad_x, _, _, _, _, event = buffer.dispatch(
            grad_output.contiguous(),
            handle=ctx.handle,
            previous_event=ctx.previous_event,
            async_finish=False,
            allocate_on_comm_stream=False,
        )
        return grad_x, (None, None, None, None, None, None)


if HAVE_DEEP_EP:

    def fused_dispatch(x, token_indices, token_probs, num_experts, group, previous_event=None):
        """Perform fused dispatch operation if deep_ep is available.

        Args:
            x: Input tensor [num_tokens, hidden_size]
            token_indices: Token routing indices [num_tokens, topk]
            token_probs: Token routing probabilities [num_tokens, topk]
            num_experts: Number of experts
            group: Process group
            previous_event: Previous CUDA event

        Returns:
            Result of FusedDispatch
        """
        return FusedDispatch.apply(
            x.contiguous(), token_indices, token_probs, num_experts, group, previous_event
        )

    def fused_combine(x, group, handle, previous_event=None):
        """Perform fused combine operation if deep_ep is available.

        Args:
            x: Input tensor
            group: Process group
            handle: Communication handle
            previous_event: Previous CUDA event

        Returns:
            Result of FusedCombine
        """
        return FusedCombine.apply(x, group, handle, previous_event)

else:
    fused_dispatch = None
    fused_combine = None