Skip to content

Commit 6497bea

Browse files
yf225 authored and pytorchmergebot committed
Support python test/test_X.py command for all unit test files (#60)
Several of our unit test files already support the `python test/test_X.py` command, which is convenient for quick testing. This PR adds that support to all remaining unit test files. Pull Request resolved: #60. Approved by: https://github.com/drisspg, https://github.com/oulgen, https://github.com/jansel
1 parent e76907d commit 6497bea

10 files changed

+69
-18
lines changed

test/test_autotuner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
import random
55
import tempfile
6+
import unittest
67
from unittest.mock import patch
78

89
from expecttest import TestCase
@@ -192,3 +193,7 @@ def add(a, b):
192193
)
193194
result = add(*args)
194195
torch.testing.assert_close(result, sum(args))
196+
197+
198+
if __name__ == "__main__":
199+
unittest.main()

test/test_closures.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from pathlib import Path
4+
import unittest
45

56
from expecttest import TestCase
67
import torch
@@ -303,3 +304,7 @@ def _call_func_arg_on_host_make_precompiler(a, alloc):
303304
from helion.runtime.precompile_shim import make_precompiler
304305
return make_precompiler(_call_func_arg_on_host_kernel)(a, out, a.size(0), a.stride(0), out.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)""",
305306
)
307+
308+
309+
if __name__ == "__main__":
310+
unittest.main()

test/test_constexpr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import unittest
4+
35
from expecttest import TestCase
46
import torch
57

@@ -179,3 +181,7 @@ def _fn_make_precompiler(x: torch.Tensor, s: hl.constexpr):
179181
from helion.runtime.precompile_shim import make_precompiler
180182
return make_precompiler(_fn_kernel)(x, out, out.stride(0), out.stride(1), x.stride(0), b, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
181183
)
184+
185+
186+
if __name__ == "__main__":
187+
unittest.main()

test/test_control_flow.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import unittest
4+
35
from expecttest import TestCase
46
import torch
57

@@ -191,3 +193,7 @@ def _fn_make_precompiler(x):
191193
from helion.runtime.precompile_shim import make_precompiler
192194
return make_precompiler(_fn_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)""",
193195
)
196+
197+
198+
if __name__ == "__main__":
199+
unittest.main()

test/test_examples.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,3 +1093,7 @@ def _attention_make_precompiler(q_in: torch.Tensor, k_in: torch.Tensor, v_in: to
10931093
from helion.runtime.precompile_shim import make_precompiler
10941094
return make_precompiler(_attention_kernel)(q_view, k_view, v_view, out, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
10951095
)
1096+
1097+
1098+
if __name__ == "__main__":
1099+
unittest.main()

test/test_logging.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import unittest
4+
35
from expecttest import TestCase
46
import torch
57

@@ -46,3 +48,7 @@ def add(x, y):
4648
self.assertTrue(
4749
any("DEBUG:helion.runtime.kernel:Debug string:" in msg for msg in cm.output)
4850
)
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main()

test/test_loops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,7 @@ def _matmul_make_precompiler(x: torch.Tensor, y: torch.Tensor):
441441
from helion.runtime.precompile_shim import make_precompiler
442442
return make_precompiler(_matmul_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)""",
443443
)
444+
445+
446+
if __name__ == "__main__":
447+
unittest.main()

test/test_matmul.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,3 +637,7 @@ def matmul_static_shapes(x: torch.Tensor, y: torch.Tensor):
637637
_matmul_static_shapes_kernel[triton.cdiv(127, _BLOCK_SIZE_0) * triton.cdiv(127, _BLOCK_SIZE_1),](x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
638638
return out""",
639639
)
640+
641+
642+
if __name__ == "__main__":
643+
unittest.main()

test/test_reductions.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from typing import TYPE_CHECKING
4+
import unittest
45

56
from expecttest import TestCase
67
import torch
@@ -216,41 +217,41 @@ def test_mean(self):
216217
reduce_kernel.bind(args)._debug_str(),
217218
"""\
218219
def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
219-
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:46>)
220+
# Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:47>)
220221
# Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
221222
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
222223
n, _m = x.size()
223-
# Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:47>)
224+
# Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
224225
# Attribute: CallableType(_VariableFunctionsClass.empty) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty')
225226
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
226-
# List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
227-
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:46>), key=0)
227+
# List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:49>)
228+
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47>), key=0)
228229
# Name: LiteralType(torch.float32) ArgumentOrigin(name='out_dtype')
229230
# Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
230231
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
231232
# For: loop_type=GRID
232233
out = torch.empty([n], dtype=out_dtype, device=x.device)
233-
# Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:52>)
234+
# Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:53>)
234235
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
235236
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
236-
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:46>), key=0)
237+
# Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47>), key=0)
237238
for tile_n in hl.tile(n):
238-
# Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:53>)
239-
# Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:47>)
240-
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:52>)
241-
# Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:53>)
239+
# Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
240+
# Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48>)
241+
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53>)
242+
# Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
242243
# Name: CallableType(_VariableFunctionsClass.mean) ArgumentOrigin(name='fn')
243-
# Subscript: TensorType([block_size_0, x_size1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:53>)
244+
# Subscript: TensorType([block_size_0, x_size1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
244245
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
245-
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:52>)
246-
# Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:53>)
247-
# UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:53>)
248-
# Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:53>)
246+
# Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53>)
247+
# Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
248+
# UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
249+
# Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:54>)
249250
out[tile_n] = fn(x[tile_n, :], dim=-1)
250251
return out
251252
252253
def root_graph_0():
253-
# File: .../test_reductions.py:53 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
254+
# File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
254255
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
255256
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
256257
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)]); x = None
@@ -261,15 +262,15 @@ def root_graph_0():
261262
return None
262263
263264
def reduction_loop_1():
264-
# File: .../test_reductions.py:53 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
265+
# File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
265266
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
266267
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
267268
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)]); x = block_size_0 = None
268269
mean_extra: "f32[u0]" = helion_language__tracing_ops__inductor_lowering_extra([load]); load = None
269270
return [mean_extra]
270271
271272
def root_graph_2():
272-
# File: .../test_reductions.py:53 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
273+
# File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
273274
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
274275
_for_loop = helion_language__tracing_ops__for_loop(1, [])
275276
getitem: "f32[u0]" = _for_loop[0]; _for_loop = None
@@ -419,3 +420,7 @@ def _reduce_kernel_make_precompiler(x: torch.Tensor, fn: Callable[[torch.Tensor]
419420
from helion.runtime.precompile_shim import make_precompiler
420421
return make_precompiler(_reduce_kernel_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)""",
421422
)
423+
424+
425+
if __name__ == "__main__":
426+
unittest.main()

test/test_views.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import unittest
4+
35
from expecttest import TestCase
46
import torch
57

@@ -272,3 +274,7 @@ def fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
272274
)
273275
_code, result = code_and_output(fn, args)
274276
torch.testing.assert_close(result, args[0] + args[1])
277+
278+
279+
if __name__ == "__main__":
280+
unittest.main()

0 commit comments

Comments (0)