1
1
from __future__ import annotations
2
2
3
3
from typing import TYPE_CHECKING
4
+ import unittest
4
5
5
6
from expecttest import TestCase
6
7
import torch
@@ -216,41 +217,41 @@ def test_mean(self):
216
217
reduce_kernel .bind (args )._debug_str (),
217
218
"""\
218
219
def reduce_kernel(x: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], out_dtype=torch.float32):
219
- # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:46 >)
220
+ # Call: SequenceType((SymIntType(s77), SymIntType(s27))) SourceOrigin(location=<SourceLocation test_reductions.py:47 >)
220
221
# Attribute: TensorAttributeType AttributeOrigin(value=ArgumentOrigin(name='x'), key='size')
221
222
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
222
223
n, _m = x.size()
223
- # Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:47 >)
224
+ # Call: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48 >)
224
225
# Attribute: CallableType(_VariableFunctionsClass.empty) AttributeOrigin(value=GlobalOrigin(name='torch'), key='empty')
225
226
# Name: PythonModuleType(torch) GlobalOrigin(name='torch')
226
- # List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:48 >)
227
- # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:46 >), key=0)
227
+ # List: SequenceType([SymIntType(s77)]) SourceOrigin(location=<SourceLocation test_reductions.py:49 >)
228
+ # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47 >), key=0)
228
229
# Name: LiteralType(torch.float32) ArgumentOrigin(name='out_dtype')
229
230
# Attribute: LiteralType(device(type='cuda', index=0)) AttributeOrigin(value=ArgumentOrigin(name='x'), key='device')
230
231
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
231
232
# For: loop_type=GRID
232
233
out = torch.empty([n], dtype=out_dtype, device=x.device)
233
- # Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:52 >)
234
+ # Call: IterType(TileIndexType(0)) SourceOrigin(location=<SourceLocation test_reductions.py:53 >)
234
235
# Attribute: CallableType(tile) AttributeOrigin(value=GlobalOrigin(name='hl'), key='tile')
235
236
# Name: PythonModuleType(helion.language) GlobalOrigin(name='hl')
236
- # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:46 >), key=0)
237
+ # Name: SymIntType(s77) GetItemOrigin(value=SourceOrigin(location=<SourceLocation test_reductions.py:47 >), key=0)
237
238
for tile_n in hl.tile(n):
238
- # Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:53 >)
239
- # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:47 >)
240
- # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:52 >)
241
- # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:53 >)
239
+ # Subscript: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
240
+ # Name: TensorType([x_size0], torch.float32) SourceOrigin(location=<SourceLocation test_reductions.py:48 >)
241
+ # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53 >)
242
+ # Call: TensorType([block_size_0], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
242
243
# Name: CallableType(_VariableFunctionsClass.mean) ArgumentOrigin(name='fn')
243
- # Subscript: TensorType([block_size_0, x_size1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:53 >)
244
+ # Subscript: TensorType([block_size_0, x_size1], torch.float32) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
244
245
# Name: TensorType([x_size0, x_size1], torch.float32) ArgumentOrigin(name='x')
245
- # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:52 >)
246
- # Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:53 >)
247
- # UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:53 >)
248
- # Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:53 >)
246
+ # Name: TileIndexType(0) SourceOrigin(location=<SourceLocation test_reductions.py:53 >)
247
+ # Slice: SliceType(LiteralType(None):LiteralType(None):LiteralType(None)) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
248
+ # UnaryOp: LiteralType(-1) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
249
+ # Constant: LiteralType(1) DeviceOrigin(location=<SourceLocation test_reductions.py:54 >)
249
250
out[tile_n] = fn(x[tile_n, :], dim=-1)
250
251
return out
251
252
252
253
def root_graph_0():
253
- # File: .../test_reductions.py:53 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
254
+ # File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
254
255
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
255
256
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
256
257
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)]); x = None
@@ -261,15 +262,15 @@ def root_graph_0():
261
262
return None
262
263
263
264
def reduction_loop_1():
264
- # File: .../test_reductions.py:53 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
265
+ # File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
265
266
x: "f32[s77, s27]" = helion_language__tracing_ops__host_tensor('x')
266
267
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
267
268
load: "f32[u0, u1]" = helion_language_memory_ops_load(x, [block_size_0, slice(None, None, None)]); x = block_size_0 = None
268
269
mean_extra: "f32[u0]" = helion_language__tracing_ops__inductor_lowering_extra([load]); load = None
269
270
return [mean_extra]
270
271
271
272
def root_graph_2():
272
- # File: .../test_reductions.py:53 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
273
+ # File: .../test_reductions.py:54 in reduce_kernel, code: out[tile_n] = fn(x[tile_n, :], dim=-1)
273
274
block_size_0: "Sym(u0)" = helion_language__tracing_ops__get_symnode('block_size_0')
274
275
_for_loop = helion_language__tracing_ops__for_loop(1, [])
275
276
getitem: "f32[u0]" = _for_loop[0]; _for_loop = None
@@ -419,3 +420,7 @@ def _reduce_kernel_make_precompiler(x: torch.Tensor, fn: Callable[[torch.Tensor]
419
420
from helion.runtime.precompile_shim import make_precompiler
420
421
return make_precompiler(_reduce_kernel_kernel)(x, out, out.size(0), x.size(0), x.size(1), out.stride(0), x.stride(0), x.stride(1), _m, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)""" ,
421
422
)
423
+
424
+
425
+ if __name__ == "__main__" :
426
+ unittest .main ()
0 commit comments