|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import functools |
3 | 4 | from pathlib import Path
|
4 | 5 | import unittest
|
5 | 6 |
|
@@ -367,3 +368,76 @@ def fn(x: torch.Tensor, block_size: int) -> torch.Tensor:
|
367 | 368 | )
|
368 | 369 | torch.testing.assert_close(result, torch.sin(args[0]))
|
369 | 370 | self.assertExpectedInline(code, """""")
|
| 371 | + |
| 372 | + def test_three_level_matmul(self): |
# Checks a matmul kernel written as three nested hl.tile loops (over m, n
# and k): the numerical result must match torch.matmul, and the generated
# Triton source must match the inline snapshot at the end of this test.
| 373 | + @helion.kernel(static_shapes=True) |
| 374 | + def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| 375 | + m, k = x.size() |
| 376 | + k2, n = y.size() |
| 377 | + assert k == k2, f"size mismatch {k} != {k2}" |
| 378 | + out = torch.empty( |
| 379 | + [m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device |
| 380 | + ) |
# Tiles nest m -> n -> k. The accumulator is created once per (m, n) tile
# and is written to `out` only after the k loop has finished.
| 381 | + for tile_m in hl.tile(m): |
| 382 | + for tile_n in hl.tile(n): |
# Accumulate in float32, independent of the dtype chosen for `out`.
| 383 | + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) |
| 384 | + for tile_k in hl.tile(k): |
| 385 | + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) |
| 386 | + out[tile_m, tile_n] = acc |
| 387 | + return out |
| 388 | + |
# Inputs are 256x512 and 512x128, so the product is 256x128.
| 389 | + args = ( |
| 390 | + torch.randn([256, 512], device=DEVICE), |
| 391 | + torch.randn([512, 128], device=DEVICE), |
| 392 | + ) |
# One block size per tile loop; in the snapshot below they surface as
# _BLOCK_SIZE_0 (m, 16), _BLOCK_SIZE_1 (n, 64) and _BLOCK_SIZE_2 (k, 64).
| 393 | + code, result = code_and_output(matmul, args, block_sizes=[16, 64, 64]) |
# functools.reduce(torch.matmul, args) is simply args[0] @ args[1].
# NOTE(review): the loose atol/rtol presumably allow for the tf32 tl.dot in
# the generated kernel - confirm.
| 394 | + torch.testing.assert_close( |
| 395 | + result, functools.reduce(torch.matmul, args), atol=1e-1, rtol=1e-2 |
| 396 | + ) |
# Snapshot of the generated code. Only the m tile is mapped to
# tl.program_id(0) (grid = cdiv(256, _BLOCK_SIZE_0)); the n and k tiles are
# emitted as for-loops inside the kernel, and the sizes 256, 128 and 512
# appear as literals. The text must match exactly, so no comments are added
# inside the string.
| 397 | + self.assertExpectedInline( |
| 398 | + code, |
| 399 | + """\ |
| 400 | +from __future__ import annotations |
| 401 | +
|
| 402 | +import torch |
| 403 | +import triton |
| 404 | +import triton.language as tl |
| 405 | +
|
| 406 | +@triton.jit |
| 407 | +def _matmul_kernel(x, y, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): |
| 408 | + pid_0 = tl.program_id(0) |
| 409 | + offset_0 = pid_0 * _BLOCK_SIZE_0 |
| 410 | + indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) |
| 411 | + for offset_1 in range(0, 128, _BLOCK_SIZE_1): |
| 412 | + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) |
| 413 | + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) |
| 414 | + for offset_2 in range(0, 512, _BLOCK_SIZE_2): |
| 415 | + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) |
| 416 | + acc_copy = acc |
| 417 | + load = tl.load(x + (indices_0[:, None] * 512 + indices_2[None, :] * 1), None) |
| 418 | + load_1 = tl.load(y + (indices_2[:, None] * 128 + indices_1[None, :] * 1), None) |
| 419 | + acc = tl.dot(load, load_1, acc=acc_copy, input_precision='tf32') |
| 420 | + tl.store(out + (indices_0[:, None] * 128 + indices_1[None, :] * 1), acc, None) |
| 421 | +
|
| 422 | +def matmul(x: torch.Tensor, y: torch.Tensor): |
| 423 | + m, k = x.size() |
| 424 | + k2, n = y.size() |
| 425 | + assert k == k2, f'size mismatch {k} != {k2}' |
| 426 | + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) |
| 427 | + _BLOCK_SIZE_0 = 16 |
| 428 | + _BLOCK_SIZE_1 = 64 |
| 429 | + _BLOCK_SIZE_2 = 64 |
| 430 | + _matmul_kernel[triton.cdiv(256, _BLOCK_SIZE_0),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) |
| 431 | + return out |
| 432 | +
|
| 433 | +def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor): |
| 434 | + m, k = x.size() |
| 435 | + k2, n = y.size() |
| 436 | + assert k == k2, f'size mismatch {k} != {k2}' |
| 437 | + out = torch.empty([m, n], dtype=torch.promote_types(x.dtype, y.dtype), device=x.device) |
| 438 | + _BLOCK_SIZE_0 = 16 |
| 439 | + _BLOCK_SIZE_1 = 64 |
| 440 | + _BLOCK_SIZE_2 = 64 |
| 441 | + from helion.runtime.precompile_shim import make_precompiler |
| 442 | + return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""", |
| 443 | + ) |
0 commit comments