|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import unittest |
| 4 | + |
| 5 | +from expecttest import TestCase |
| 6 | +import torch |
| 7 | + |
| 8 | +import helion |
| 9 | +from helion._testing import DEVICE |
| 10 | +from helion._testing import code_and_output |
| 11 | +import helion.language as hl |
| 12 | + |
| 13 | + |
| 14 | +@helion.kernel() |
| 15 | +def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| 16 | + """Test basic atomic_add functionality.""" |
| 17 | + for i in hl.tile([x.size(0)]): |
| 18 | + hl.atomic_add(x, [i], y[i]) |
| 19 | + return x |
| 20 | + |
| 21 | + |
| 22 | +@helion.kernel() |
| 23 | +def atomic_add_overlap_kernel( |
| 24 | + x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor |
| 25 | +) -> torch.Tensor: |
| 26 | + """Test atomic_add with overlapping indices.""" |
| 27 | + for i in hl.tile([y.size(0)]): |
| 28 | + idx = indices[i] |
| 29 | + hl.atomic_add(x, [idx], y[i]) |
| 30 | + return x |
| 31 | + |
| 32 | + |
| 33 | +class TestAtomicOperations(TestCase): |
| 34 | + maxDiff = 16384 |
| 35 | + |
| 36 | + def test_basic_atomic_add(self): |
| 37 | + # Basic test with sequential indices |
| 38 | + x = torch.zeros(10, device=DEVICE) |
| 39 | + y = torch.ones(10, device=DEVICE) |
| 40 | + args = (x, y) |
| 41 | + |
| 42 | + code, result = code_and_output( |
| 43 | + atomic_add_kernel, |
| 44 | + args, |
| 45 | + block_sizes=[32], |
| 46 | + ) |
| 47 | + |
| 48 | + expected = torch.ones(10, device=DEVICE) |
| 49 | + torch.testing.assert_close(result, expected) |
| 50 | + self.assertExpectedInline( |
| 51 | + code, |
| 52 | + """\ |
| 53 | +from __future__ import annotations |
| 54 | +
|
| 55 | +import torch |
| 56 | +import triton |
| 57 | +import triton.language as tl |
| 58 | +
|
| 59 | +@triton.jit |
| 60 | +def _atomic_add_kernel_kernel(x, y, x_size_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr): |
| 61 | + pid_0 = tl.program_id(0) |
| 62 | + offset_0 = pid_0 * _BLOCK_SIZE_0 |
| 63 | + indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) |
| 64 | + mask_0 = indices_0 < x_size_0 |
| 65 | + load = tl.load(y + indices_0 * y_stride_0, mask_0, other=0) |
| 66 | + tl.atomic_add(x + indices_0 * x_stride_0, load, mask=mask_0, sem='relaxed') |
| 67 | +
|
| 68 | +def atomic_add_kernel(x: torch.Tensor, y: torch.Tensor): |
| 69 | + \"\"\"Test basic atomic_add functionality.\"\"\" |
| 70 | + _BLOCK_SIZE_0 = 32 |
| 71 | + _atomic_add_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) |
| 72 | + return x |
| 73 | +
|
| 74 | +def _atomic_add_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor): |
| 75 | + \"\"\"Test basic atomic_add functionality.\"\"\" |
| 76 | + _BLOCK_SIZE_0 = 32 |
| 77 | + from helion.runtime.precompile_shim import make_precompiler |
| 78 | + return make_precompiler(_atomic_add_kernel_kernel)(x, y, x.size(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""", |
| 79 | + ) |
| 80 | + |
| 81 | + def test_overlapping_atomic_add(self): |
| 82 | + # Test with overlapping indices |
| 83 | + x = torch.zeros(5, device=DEVICE) |
| 84 | + y = torch.ones(10, device=DEVICE) |
| 85 | + indices = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4], device=DEVICE) |
| 86 | + args = (x, y, indices) |
| 87 | + |
| 88 | + code, result = code_and_output( |
| 89 | + atomic_add_overlap_kernel, |
| 90 | + args, |
| 91 | + block_sizes=[32], |
| 92 | + ) |
| 93 | + |
| 94 | + expected = torch.ones(5, device=DEVICE) * 2 |
| 95 | + torch.testing.assert_close(result, expected) |
| 96 | + self.assertExpectedInline( |
| 97 | + code, |
| 98 | + """\ |
| 99 | +from __future__ import annotations |
| 100 | +
|
| 101 | +import torch |
| 102 | +import triton |
| 103 | +import triton.language as tl |
| 104 | +
|
| 105 | +@triton.jit |
| 106 | +def _atomic_add_overlap_kernel_kernel(y, indices, x, y_size_0, indices_stride_0, x_stride_0, y_stride_0, _BLOCK_SIZE_0: tl.constexpr): |
| 107 | + pid_0 = tl.program_id(0) |
| 108 | + offset_0 = pid_0 * _BLOCK_SIZE_0 |
| 109 | + indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) |
| 110 | + mask_0 = indices_0 < y_size_0 |
| 111 | + idx = tl.load(indices + indices_0 * indices_stride_0, mask_0, other=0) |
| 112 | + load_1 = tl.load(y + indices_0 * y_stride_0, mask_0, other=0) |
| 113 | + tl.atomic_add(x + idx * x_stride_0, load_1, mask=mask_0, sem='relaxed') |
| 114 | +
|
| 115 | +def atomic_add_overlap_kernel(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor): |
| 116 | + \"\"\"Test atomic_add with overlapping indices.\"\"\" |
| 117 | + _BLOCK_SIZE_0 = 32 |
| 118 | + _atomic_add_overlap_kernel_kernel[triton.cdiv(y.size(0), _BLOCK_SIZE_0),](y, indices, x, y.size(0), indices.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) |
| 119 | + return x |
| 120 | +
|
| 121 | +def _atomic_add_overlap_kernel_make_precompiler(x: torch.Tensor, y: torch.Tensor, indices: torch.Tensor): |
| 122 | + \"\"\"Test atomic_add with overlapping indices.\"\"\" |
| 123 | + _BLOCK_SIZE_0 = 32 |
| 124 | + from helion.runtime.precompile_shim import make_precompiler |
| 125 | + return make_precompiler(_atomic_add_overlap_kernel_kernel)(y, indices, x, y.size(0), indices.stride(0), x.stride(0), y.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""", |
| 126 | + ) |
| 127 | + |
| 128 | + def test_atomic_add_code_generation(self): |
| 129 | + """Test that the generated code contains atomic_add.""" |
| 130 | + x = torch.zeros(10, device=DEVICE) |
| 131 | + y = torch.ones(10, device=DEVICE) |
| 132 | + args = (x, y) |
| 133 | + |
| 134 | + code, _ = code_and_output(atomic_add_kernel, args) |
| 135 | + # Verify atomic_add appears in the generated code |
| 136 | + self.assertIn("atomic_add", code) |
| 137 | + # Verify the new signature format (using the target tensor and indices list) |
| 138 | + self.assertIn("tl.atomic_add(x + offset_0", code) |
| 139 | + |
| 140 | + |
| 141 | +if __name__ == "__main__": |
| 142 | + unittest.main() |
0 commit comments