@@ -1466,43 +1466,38 @@ def _moe_matmul_ogs_kernel(expert_token_offsets, expert_token_counts, sorted_to_
1466
1466
indices_0 = offset_0 + tl.zeros([1], tl.int32)
1467
1467
start = tl.load(expert_token_offsets + indices_0 * expert_token_offsets_stride_0, None)
1468
1468
num_tokens = tl.load(expert_token_counts + indices_0 * expert_token_counts_stride_0, None)
1469
- v_0 = tl.full([], 0, tl.int32)
1470
- v_1 = num_tokens != v_0
1471
- if v_1:
1472
- num_tokens_copy = num_tokens
1473
- start_copy = start
1474
- for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1):
1475
- indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
1476
- mask_1 = indices_1 < max_T_per_expert
1477
- for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2):
1478
- indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
1479
- mask_2 = indices_2 < N
1480
- num_tokens_copy_copy = num_tokens_copy
1481
- start_copy_copy = start_copy
1482
- v_2 = num_tokens_copy_copy[None]
1483
- v_3 = indices_1 < v_2
1484
- v_4 = tl.full([], 0, tl.int32)
1485
- v_5 = v_4[None]
1486
- v_6 = tl.where(v_3, indices_1, v_5)
1487
- v_7 = start_copy_copy[None]
1488
- v_8 = v_7 + v_6
1489
- squeeze = tl.reshape(v_8, [_BLOCK_SIZE_1])
1490
- expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0)
1491
- acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
1492
- for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3):
1493
- indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
1494
- mask_3 = indices_3 < K
1495
- expert_orig_token_indices_copy = expert_orig_token_indices
1496
- acc_copy = acc
1497
- A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
1498
- W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
1499
- acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32')
1500
- existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
1501
- view = tl.reshape(v_3, [_BLOCK_SIZE_1, 1])
1502
- mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2])
1503
- v_9 = acc.to(tl.float16)
1504
- v_10 = tl.where(mask_2d, v_9, existing_values)
1505
- tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_10, mask_1[:, None] & mask_2[None, :])
1469
+ for offset_1 in range(0, max_T_per_expert.to(tl.int32), _BLOCK_SIZE_1):
1470
+ indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
1471
+ mask_1 = indices_1 < max_T_per_expert
1472
+ for offset_2 in range(0, N.to(tl.int32), _BLOCK_SIZE_2):
1473
+ indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
1474
+ mask_2 = indices_2 < N
1475
+ num_tokens_copy = num_tokens
1476
+ start_copy = start
1477
+ v_0 = num_tokens_copy[None]
1478
+ v_1 = indices_1 < v_0
1479
+ v_2 = tl.full([], 0, tl.int32)
1480
+ v_3 = v_2[None]
1481
+ v_4 = tl.where(v_1, indices_1, v_3)
1482
+ v_5 = start_copy[None]
1483
+ v_6 = v_5 + v_4
1484
+ squeeze = tl.reshape(v_6, [_BLOCK_SIZE_1])
1485
+ expert_orig_token_indices = tl.load(sorted_to_orig_token_idx + squeeze * sorted_to_orig_token_idx_stride_0, mask_1, other=0)
1486
+ acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_2], 0.0, tl.float32)
1487
+ for offset_3 in range(0, K.to(tl.int32), _BLOCK_SIZE_3):
1488
+ indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
1489
+ mask_3 = indices_3 < K
1490
+ expert_orig_token_indices_copy = expert_orig_token_indices
1491
+ acc_copy = acc
1492
+ A_frag = tl.load(A + (expert_orig_token_indices_copy[:, None] * A_stride_0 + indices_3[None, :] * A_stride_1), mask_1[:, None] & mask_3[None, :], other=0)
1493
+ W_frag = tl.load(W + (indices_0 * W_stride_0 + indices_3[:, None] * W_stride_1 + indices_2[None, :] * W_stride_2), mask_3[:, None] & mask_2[None, :], other=0)
1494
+ acc = tl.dot(A_frag, W_frag, acc=acc_copy, input_precision='tf32')
1495
+ existing_values = tl.load(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), mask_1[:, None] & mask_2[None, :], other=0)
1496
+ view = tl.reshape(v_1, [_BLOCK_SIZE_1, 1])
1497
+ mask_2d = tl.broadcast_to(view, [_BLOCK_SIZE_1, _BLOCK_SIZE_2])
1498
+ v_7 = acc.to(tl.float16)
1499
+ v_8 = tl.where(mask_2d, v_7, existing_values)
1500
+ tl.store(C + (expert_orig_token_indices[:, None] * C_stride_0 + indices_2[None, :] * C_stride_1), v_8, mask_1[:, None] & mask_2[None, :])
1506
1501
1507
1502
def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.Tensor, expert_token_offsets: torch.Tensor, sorted_to_orig_token_idx: torch.Tensor, max_T_per_expert_tensor: torch.Tensor):
1508
1503
T, K = A.shape
0 commit comments