
Commit 82b93d6

Fix internal issue for moe_matmul_ogs example
1 parent: 864df06

2 files changed (+75, -82 lines)


examples/moe_matmul_ogs.py

Lines changed: 43 additions & 45 deletions
@@ -39,51 +39,49 @@ def moe_matmul_ogs(
         start = expert_token_offsets[e_idx]  # Starting index in sorted token array
         num_tokens = expert_token_counts[e_idx]  # Number of tokens for this expert
 
-        # Skip experts with no assigned tokens
-        if num_tokens != 0:
-            # Tile over tokens and output features for this expert
-            for tile_t, tile_n in hl.tile([max_T_per_expert, N]):
-                # Get local token offsets for this tile
-                # (i.e. the tile's corresponding chunk in [0 .. max_T_per_expert-1] token range)
-                local_token_offsets = tile_t.index  # [BLOCK_T]
-
-                # Create mask for valid tokens (some tiles may be partially filled)
-                token_valid = local_token_offsets < num_tokens  # bool[BLOCK_T]
-
-                # For invalid tokens, use 0 as a dummy index (will be masked out later)
-                local_token_offsets_valid = torch.where(
-                    token_valid,
-                    local_token_offsets,
-                    0,
-                )  # [BLOCK_T]
-
-                # Convert local offsets to global sorted indices
-                expert_sorted_token_indices = (
-                    start + local_token_offsets_valid
-                )  # [1, BLOCK_T]
-
-                # Map sorted indices back to global original token positions
-                expert_orig_token_indices = sorted_to_orig_token_idx[
-                    expert_sorted_token_indices.squeeze(0)
-                ]  # [BLOCK_T]
-
-                acc = hl.zeros([tile_t, tile_n], dtype=torch.float32)
-
-                # Perform tiled matrix multiplication: A[tokens, :] @ W[expert, :, :]
-                for tile_k in hl.tile(K):
-                    A_frag = A[expert_orig_token_indices, tile_k]  # [BLOCK_T, BLOCK_K]
-                    W_frag = W[e_idx, tile_k, tile_n]  # [BLOCK_K, BLOCK_N]
-                    acc = torch.addmm(acc, A_frag, W_frag)
-
-                # Write results back to output tensor, masking out invalid tokens
-                block_T = acc.size(0)
-                block_N = acc.size(1)
-                existing_values = C[expert_orig_token_indices, tile_n]
-                mask_2d = token_valid.view(block_T, 1).expand(block_T, block_N)
-                # Write results only for valid tokens, preserve existing values for invalid ones
-                C[expert_orig_token_indices, tile_n] = torch.where(
-                    mask_2d, acc.to(C.dtype), existing_values
-                )
+        # Tile over tokens and output features for this expert
+        for tile_t, tile_n in hl.tile([max_T_per_expert, N]):
+            # Get local token offsets for this tile
+            # (i.e. the tile's corresponding chunk in [0 .. max_T_per_expert-1] token range)
+            local_token_offsets = tile_t.index  # [BLOCK_T]
+
+            # Create mask for valid tokens (some tiles may be partially filled)
+            token_valid = local_token_offsets < num_tokens  # bool[BLOCK_T]
+
+            # For invalid tokens, use 0 as a dummy index (will be masked out later)
+            local_token_offsets_valid = torch.where(
+                token_valid,
+                local_token_offsets,
+                0,
+            )  # [BLOCK_T]
+
+            # Convert local offsets to global sorted indices
+            expert_sorted_token_indices = (
+                start + local_token_offsets_valid
+            )  # [1, BLOCK_T]
+
+            # Map sorted indices back to global original token positions
+            expert_orig_token_indices = sorted_to_orig_token_idx[
+                expert_sorted_token_indices.squeeze(0)
+            ]  # [BLOCK_T]
+
+            acc = hl.zeros([tile_t, tile_n], dtype=torch.float32)
+
+            # Perform tiled matrix multiplication: A[tokens, :] @ W[expert, :, :]
+            for tile_k in hl.tile(K):
+                A_frag = A[expert_orig_token_indices, tile_k]  # [BLOCK_T, BLOCK_K]
+                W_frag = W[e_idx, tile_k, tile_n]  # [BLOCK_K, BLOCK_N]
+                acc = torch.addmm(acc, A_frag, W_frag)
+
+            # Write results back to output tensor, masking out invalid tokens
+            block_T = acc.size(0)
+            block_N = acc.size(1)
+            existing_values = C[expert_orig_token_indices, tile_n]
+            mask_2d = token_valid.view(block_T, 1).expand(block_T, block_N)
+            # Write results only for valid tokens, preserve existing values for invalid ones
+            C[expert_orig_token_indices, tile_n] = torch.where(
+                mask_2d, acc.to(C.dtype), existing_values
+            )
 
     return C

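For reference, the gather/scatter bookkeeping in this example consumes the usual outer-gather-scatter (OGS) routing tensors. Below is a minimal host-side sketch of how expert_token_counts, expert_token_offsets, and sorted_to_orig_token_idx could be built from a per-token expert assignment; the helper name build_ogs_metadata and the dtypes are illustrative assumptions, not code from this commit.

import torch

def build_ogs_metadata(expert_assignment: torch.Tensor, n_experts: int):
    # expert_assignment: int64[T], the expert id chosen for each token (assumed input).
    # Sort tokens by expert so each expert's tokens are contiguous in sorted order;
    # the sort indices map sorted position -> original token position.
    sorted_expert_ids, sorted_to_orig_token_idx = torch.sort(expert_assignment)
    # Per-expert token counts and the start offset of each expert's segment
    # (exclusive prefix sum of the counts).
    expert_token_counts = torch.bincount(sorted_expert_ids, minlength=n_experts)
    expert_token_offsets = torch.cumsum(expert_token_counts, dim=0) - expert_token_counts
    return (
        expert_token_counts.to(torch.int32),
        expert_token_offsets.to(torch.int32),
        sorted_to_orig_token_idx.to(torch.int32),
    )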
test/test_examples.py

Lines changed: 32 additions & 37 deletions
@@ -1466,43 +1466,38 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
     indices_0 = offset_0 + tl.zeros([1], tl.int32)
     start = tl.load(expert_token_offsets + indices_0 * expert_token_offsets_stride_0, None)
     num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
-    v_0 = tl.full([], 0, tl.int32)
-    v_1 = num_tokens != v_0
-    if v_1:
-        num_tokens_copy = num_tokens
-        start_copy = start
-        for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1):
-            indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
-            mask_1 = indices_1 < max_T_per_expert
-            for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2):
-                indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
-                mask_2 = indices_2 < N
-                num_tokens_copy_copy = num_tokens_copy
-                start_copy_copy = start_copy
-                v_2 = num_tokens_copy_copy[None]
-                v_3 = indices_1 < v_2
-                v_4 = tl.full([], 0, tl.int32)
-                v_5 = v_4[None]
-                v_6 = tl.where(v_3, indices_1, v_5)
-                v_7 = start_copy_copy[None]
-                v_8 = v_7 + v_6
-                squeeze = tl.reshape(v_8, [_BLOCK_SIZE_1])
-                expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0)
-                acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
-                for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3):
-                    indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
-                    mask_3 = indices_3 < K
-                    expert_orig_token_indices_copy = expert_orig_token_indices
-                    acc_copy = acc
-                    A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
-                    W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
-                    acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32')
-                existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
-                view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])
-                mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2])
-                v_9 = acc.to(tl.float16)
-                v_10 = tl.where(mask_2d, v_9, existing_values)
-                tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :])
+    for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < max_T_per_expert
+        for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+            mask_2 = indices_2 < N
+            num_tokens_copy = num_tokens
+            start_copy = start
+            v_0 = num_tokens_copy[None]
+            v_1 = indices_1 < v_0
+            v_2 = tl.full([], 0, tl.int32)
+            v_3 = v_2[None]
+            v_4 = tl.where(v_1, indices_1, v_3)
+            v_5 = start_copy[None]
+            v_6 = v_5 + v_4
+            squeeze = tl.reshape(v_6, [_BLOCK_SIZE_1])
+            expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0)
+            acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
+            for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3):
+                indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+                mask_3 = indices_3 < K
+                expert_orig_token_indices_copy = expert_orig_token_indices
+                acc_copy = acc
+                A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
+                W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
+                acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32')
+            existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
+            view = tl.reshape(v_1, [_BLOCK_SIZE_1, 1])
+            mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2])
+            v_7 = acc.to(tl.float16)
+            v_8 = tl.where(mask_2d, v_7, existing_values)
+            tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_8, mask_1[:, None] & mask_2[None, :])
 
 def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.Tensor, expert_token_offsets: torch.Tensor, sorted_to_orig_token_idx: torch.Tensor, max_T_per_expert_tensor: torch.Tensor):
     T, K = A.shape
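As a rough usage sketch, assuming the build_ogs_metadata helper above and hypothetical shapes and dtypes (none of these values come from the commit), the example kernel could be driven along these lines:

# Assumed problem sizes: T tokens, K input features, N output features, E experts.
T, K, N, E = 256, 128, 64, 8
A = torch.randn(T, K, device="cuda", dtype=torch.float16)
W = torch.randn(E, K, N, device="cuda", dtype=torch.float16)
expert_assignment = torch.randint(0, E, (T,), device="cuda")
counts, offsets, sorted_to_orig = build_ogs_metadata(expert_assignment, E)
# Upper bound on tokens handled by any single expert, passed as a 0-dim tensor.
max_T_per_expert_tensor = counts.max()
C = moe_matmul_ogs(A, W, counts, offsets, sorted_to_orig, max_T_per_expert_tensor)
assert C.shape == (T, N)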
