Skip to content

Commit 1e2882c

Browse files
committed
add wip flex dispatcher
1 parent 40c3530 commit 1e2882c

File tree

3 files changed

+698
-0
lines changed

3 files changed

+698
-0
lines changed

paddlenlp/transformers/fused_a2a.py

+209
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
# Portions of this code are from DeepSeek DeepEP project
3+
# Copyright (c) 2025 DeepSeek
4+
# Licensed under the MIT License - https://github.com/deepseek-ai/DeepEP/blob/main/LICENSE
5+
6+
try:
7+
from paddle.distributed import deep_ep
8+
9+
HAVE_DEEP_EP = True
10+
except ImportError:
11+
HAVE_DEEP_EP = False
12+
13+
import paddle
14+
from paddle.distributed.communication.group import Group
15+
from paddle.autograd import PyLayer
16+
17+
_buffer = None
18+
19+
20+
def get_hidden_bytes(x: paddle.Tensor) -> int:
21+
"""Calculate the number of hidden bytes for a tensor.
22+
23+
Args:
24+
x (paddle.Tensor): Input tensor
25+
26+
Returns:
27+
int: Number of hidden bytes
28+
"""
29+
return x.shape[1] * max(x.element_size(), 2)
30+
31+
32+
def get_buffer(group: Group, hidden_bytes: int):
33+
"""Get or create a buffer for all-to-all communication.
34+
35+
Args:
36+
group (paddle.distributed.ProcessGroup): Process group for communication
37+
hidden_bytes (int): Number of hidden bytes needed
38+
39+
Returns:
40+
Buffer: Communication buffer
41+
"""
42+
global _buffer
43+
num_nvl_bytes, num_rdma_bytes = 0, 0
44+
num_nvl_bytes = int(1e9)
45+
# TODO: hongqing
46+
# for config in (
47+
# deep_ep.Buffer.get_dispatch_config(group.world_size),
48+
# deep_ep.Buffer.get_combine_config(group.world_size),
49+
# ):
50+
# # Split long line for PEP8 compliance
51+
# num_nvl_bytes = max(
52+
# config.get_nvl_buffer_size_hint(hidden_bytes, group.world_size), num_nvl_bytes
53+
# )
54+
# num_rdma_bytes = max(
55+
# config.get_rdma_buffer_size_hint(hidden_bytes, group.world_size), num_rdma_bytes
56+
# )
57+
58+
# Allocate buffer if not existed or not enough buffer
59+
# NOTES: the adaptive routing configuration of the network **must be off**
60+
if (
61+
_buffer is None
62+
or _buffer.group != group
63+
or _buffer.num_nvl_bytes < num_nvl_bytes
64+
or _buffer.num_rdma_bytes < num_rdma_bytes
65+
):
66+
_buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes)
67+
return _buffer
68+
69+
70+
class FusedDispatch(PyLayer):
71+
"""Fused dispatch operation for MoE routing combining computation and communication."""
72+
73+
@staticmethod
74+
def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None):
75+
"""Forward pass of fused dispatch."""
76+
# Calculate layout before actual dispatch
77+
buffer = get_buffer(group, get_hidden_bytes(x))
78+
(
79+
num_tokens_per_rank,
80+
num_tokens_per_rdma_rank,
81+
num_tokens_per_expert,
82+
is_token_in_rank,
83+
previous_event,
84+
) = buffer.get_dispatch_layout(
85+
token_indices,
86+
num_experts,
87+
previous_event=None,
88+
async_finish=False,
89+
allocate_on_comm_stream=False,
90+
)
91+
92+
# Do MoE dispatch
93+
# NOTES: the CPU will wait for GPU's signal to arrive,
94+
# so this is not compatible with CUDA graph
95+
(
96+
recv_x,
97+
recv_token_indices,
98+
recv_token_probs,
99+
num_recv_tokens_per_expert_list,
100+
handle,
101+
event,
102+
) = buffer.dispatch(
103+
x,
104+
topk_idx=token_indices,
105+
topk_weights=token_probs.cast(paddle.float32),
106+
num_tokens_per_rank=num_tokens_per_rank,
107+
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
108+
is_token_in_rank=is_token_in_rank,
109+
num_tokens_per_expert=num_tokens_per_expert,
110+
previous_event=None,
111+
async_finish=False,
112+
allocate_on_comm_stream=False,
113+
)
114+
115+
ctx.group = group
116+
ctx.handle = handle
117+
ctx.event = event
118+
tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list)
119+
120+
return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle)
121+
122+
@staticmethod
123+
def backward(
124+
ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle
125+
):
126+
"""Backward pass of fused dispatch."""
127+
buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
128+
handle = ctx.handle
129+
130+
grad_x, grad_token_probs, event = buffer.combine(
131+
grad_output.contiguous(),
132+
handle,
133+
topk_weights=grad_token_probs.cast(paddle.float32),
134+
previous_event=None,
135+
async_finish=False,
136+
allocate_on_comm_stream=False,
137+
)
138+
return grad_x, None, grad_token_probs, None, None, None
139+
140+
141+
class FusedCombine(PyLayer):
142+
"""Fused combine operation for MoE output combining computation and communication."""
143+
144+
@staticmethod
145+
def forward(ctx, x, group, handle, previous_event=None):
146+
"""Forward pass of fused combine."""
147+
buffer = get_buffer(group, get_hidden_bytes(x))
148+
combined_x, _, event = buffer.combine(
149+
x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False
150+
)
151+
ctx.handle = handle
152+
ctx.group = group
153+
ctx.previous_event=previous_event
154+
155+
# return combined_x, event
156+
return combined_x
157+
158+
@staticmethod
159+
def backward(ctx, grad_output):
160+
"""Backward pass of fused combine."""
161+
buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output))
162+
grad_x, _, _, _, _, event = buffer.dispatch(
163+
grad_output.contiguous(),
164+
handle=ctx.handle,
165+
previous_event=ctx.previous_event,
166+
async_finish=False,
167+
allocate_on_comm_stream=False,
168+
)
169+
return grad_x, (None, None, None, None, None, None)
170+
171+
172+
if HAVE_DEEP_EP:
173+
174+
def fused_dispatch(x, token_indices, token_probs, num_experts, group, previous_event=None):
175+
"""Perform fused dispatch operation if deep_ep is available.
176+
177+
Args:
178+
x: Input tensor [num_tokens, hidden_size]
179+
token_indices: Token routing indices [num_tokens, topk]
180+
token_probs: Token routing probabilities [num_tokens, topk]
181+
num_experts: Number of experts
182+
group: Process group
183+
previous_event: Previous CUDA event
184+
185+
Returns:
186+
Result of FusedDispatch
187+
"""
188+
return FusedDispatch.apply(
189+
x.contiguous(), token_indices, token_probs, num_experts, group, previous_event
190+
)
191+
192+
def fused_combine(x, group, handle, previous_event=None):
193+
"""Perform fused combine operation if deep_ep is available.
194+
195+
Args:
196+
x: Input tensor
197+
group: Process group
198+
handle: Communication handle
199+
previous_event: Previous CUDA event
200+
201+
Returns:
202+
Result of FusedCombine
203+
"""
204+
print(f'wsm fused_combine handle: {handle}')
205+
return FusedCombine.apply(x, group, handle, previous_event)
206+
207+
else:
208+
fused_dispatch = None
209+
fused_combine = None

paddlenlp/transformers/moe_utils.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import paddle
2+
from typing import Optional
3+
4+
def permute(
5+
tokens,
6+
routing_map,
7+
num_out_tokens: Optional[int] = None,
8+
fused: bool = False,
9+
drop_and_pad: bool = False,
10+
):
11+
"""Permute the tokens and probs based on the mask.
12+
Tokens with the same designated expert will be grouped together.
13+
The shape of mask is [tokens, num_experts], it indicates which experts were selected
14+
by each token.
15+
16+
When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to
17+
expert capacity. This function exploits this feature to use ops that support cuda graph.
18+
19+
Args:
20+
tokens (paddle.Tensor): The input token tensor, [num_tokens, hidden].
21+
routing_map (paddle.Tensor): The sparse token to expert mapping, [num_tokens, num_experts].
22+
num_out_tokens (int, optional): The number of output tokens. If None, it's set to
23+
the number of input tokens.
24+
fused (bool, optional): Whether use the fused permute function.
25+
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
26+
and pads the number of tokens to the expert capacity.
27+
If set to true, routing_map has a fixed number of non-zeros
28+
in each column.
29+
"""
30+
if fused:
31+
if not HAVE_TE or fused_permute is None:
32+
raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.")
33+
return fused_permute(tokens, routing_map, num_out_tokens)
34+
35+
num_tokens, hidden = tokens.shape
36+
num_experts = routing_map.shape[1]
37+
if drop_and_pad and not (num_out_tokens is None):
38+
capacity = num_out_tokens // num_experts
39+
assert not routing_map.requires_grad
40+
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
41+
routing_map = routing_map.to(dtype=paddle.int8).t().contiguous()
42+
# use argsort to put indices of all non-zeros in the beginning of list
43+
# and keep the first `capacity` number of indices
44+
sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
45+
:, :capacity
46+
].contiguous()
47+
# flatten from [num_experts, capacity] to 1D
48+
sorted_indices = sorted_indices.view(-1)
49+
else:
50+
# mask [num_tokens, num_experts] -> [num_experts, num_tokens]
51+
routing_map = routing_map.cast(paddle.bool).T.contiguous()
52+
53+
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
54+
token_indices = (
55+
paddle.arange(num_tokens).unsqueeze(0).expand([num_experts, -1])
56+
)
57+
sorted_indices = token_indices.masked_select(routing_map)
58+
59+
# use the mapping to permute the tokens
60+
permuted_input = tokens.index_select(axis=0, index=sorted_indices)
61+
62+
return permuted_input, sorted_indices
63+
64+
65+
def unpermute(
66+
permuted_tokens: paddle.Tensor,
67+
sorted_indices: paddle.Tensor,
68+
restore_shape: paddle.shape,
69+
probs: paddle.Tensor = None,
70+
routing_map: paddle.Tensor = None,
71+
fused: bool = False,
72+
drop_and_pad: bool = False,
73+
):
74+
"""
75+
Restore the original order of tokens after permutation. If probs are provided, it
76+
will also apply them to the tokens before restoring the order.
77+
78+
When drop_and_pad=True, the tensors will have the following properties:
79+
- In routing_map, the number of non-zeros in each column equals to expert capacity
80+
- The size of sorted_indices equals to num_experts * capacity, each split of `capacity`
81+
contains the indices of tokens routed to an expert.
82+
This function exploits these features to use ops that support cuda graph.
83+
84+
Args:
85+
permuted_tokens (paddle.Tensor): The permuted token tensor.
86+
sorted_indices (paddle.Tensor): The indices used to sort the tokens.
87+
restore_shape (paddle.shape): The shape of the unpermuted tensor.
88+
probs (paddle.Tensor, optional): The unpermuted probs tensor,
89+
routing_map (paddle.Tensor, optional): Token to expert mapping, shape
90+
[num_tokens, num_experts].
91+
fused (bool, optional): Whether use the fused unpermute function.
92+
drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop
93+
and pads the number of tokens to the expert capacity.
94+
95+
Returns:
96+
paddle.Tensor: The tokens restored to their original order.
97+
"""
98+
if fused:
99+
if not HAVE_TE or fused_unpermute is None:
100+
raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.")
101+
return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape)
102+
103+
_, hidden = restore_shape
104+
105+
if probs is not None:
106+
assert routing_map is not None, "Mask must be provided to permute the probs."
107+
if drop_and_pad:
108+
num_experts = routing_map.shape[1]
109+
num_permuted_tokens = sorted_indices.shape[0]
110+
capacity = num_permuted_tokens // num_experts
111+
num_unpermuted_tokens = probs.shape[0]
112+
113+
# [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens
114+
probs_T_1D = probs.T.contiguous().view(-1)
115+
116+
# get 1D indices of the probs selected by routing_map
117+
indices_dim0 = paddle.arange(num_experts).unsqueeze(-1)
118+
indices_dim1 = sorted_indices.view(num_experts, capacity)
119+
indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
120+
121+
# get probs from indices
122+
permuted_probs = probs_T_1D.index_select(axis=0, index=indices_1D)
123+
else:
124+
permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous())
125+
permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
126+
127+
# Create an output tensor filled with zeros
128+
output_tokens = paddle.zeros(
129+
restore_shape, dtype=permuted_tokens.dtype
130+
)
131+
# Scatter add the permuted_input back to the original positions
132+
output_tokens.put_along_axis_(axis=0, indices=sorted_indices.unsqueeze(1).expand([-1, hidden]), values=permuted_tokens, reduce='add', include_self=True)
133+
return output_tokens

0 commit comments

Comments
 (0)