Skip to content

Commit 3ccadfd

Browse files
committed
Add support for multiple top level for loops
ghstack-source-id: 1f33fe9 Pull Request resolved: #52
1 parent fd740c1 commit 3ccadfd

16 files changed

+737
-116
lines changed

helion/_compiler/ast_extension.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ def __init__(
4949
_type_info: TypeInfo | None = None,
5050
_loop_type: LoopType = LoopType.UNSET,
5151
_is_kernel_call: bool = False,
52+
_root_id: int | None = None,
5253
**kwargs: object,
5354
) -> None:
5455
super().__init__(**kwargs)
5556
self._type_info: TypeInfo | None = _type_info
5657
self._location: SourceLocation = _location
5758
self._loop_type: LoopType = _loop_type
5859
self._is_kernel_call: bool = _is_kernel_call
60+
self._root_id: int | None = _root_id
5961

6062
def new(self, fields: dict[str, object]) -> ExtendedAST:
6163
result = self.__class__(
@@ -64,6 +66,7 @@ def new(self, fields: dict[str, object]) -> ExtendedAST:
6466
_type_info=self._type_info,
6567
_loop_type=self._loop_type,
6668
_is_kernel_call=self._is_kernel_call,
69+
_root_id=self._root_id,
6770
)
6871
return self._location.to_ast(result)
6972

helion/_compiler/ast_read_writes.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,34 @@
1010
_A = TypeVar("_A", bound=ast.AST)
1111

1212

13+
# TODO(oulgen): This visitor is extremely primitive, does not consider alpha renaming or scopes
1314
class _ReadWriteVisitor(ast.NodeVisitor):
1415
def __init__(self) -> None:
1516
super().__init__()
1617
self.rw = ReadWrites(collections.Counter(), collections.Counter())
1718

19+
def _update(self, name: str, ctx: ast.expr_context) -> None:
20+
if isinstance(ctx, ast.Load):
21+
self.rw.reads[name] += 1
22+
elif isinstance(ctx, ast.Store):
23+
self.rw.writes[name] += 1
24+
1825
def visit_Name(self, node: ast.Name) -> None:
19-
if isinstance(node.ctx, ast.Load):
20-
self.rw.reads[node.id] += 1
21-
elif isinstance(node.ctx, ast.Store):
22-
self.rw.writes[node.id] += 1
23-
self.generic_visit(node)
26+
self._update(node.id, node.ctx)
27+
28+
def visit_Subscript(self, node: ast.Subscript) -> None:
29+
if isinstance(node.value, ast.Name):
30+
self._update(node.value.id, node.ctx)
31+
else:
32+
self.visit(node.value)
33+
34+
def visit_For(self, node: ast.For) -> None:
35+
# Skip target
36+
self.visit(node.iter)
37+
for stmt in node.body:
38+
self.visit(stmt)
39+
for stmt in node.orelse:
40+
self.visit(stmt)
2441

2542

2643
class ReadWrites(typing.NamedTuple):

helion/_compiler/compile_environment.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .. import exc
2121
from ..language.constexpr import ConstExpr
2222
from .error_reporting import ErrorReporting
23+
from .loop_dependency_checker import LoopDependencyChecker
2324
from .variable_origin import BlockSizeOrigin
2425
from .variable_origin import Origin
2526

@@ -70,6 +71,7 @@ def __init__(self, device: torch.device, settings: Settings) -> None:
7071
collections.Counter()
7172
)
7273
self.specialized_vars: set[sympy.Symbol] = set()
74+
self.loop_dependency_checker = LoopDependencyChecker()
7375

7476
def add_kernel_tensor_size(self, sizes: Sequence[int | torch.SymInt]) -> None:
7577
from .tile_strategy import TileStrategy
@@ -280,6 +282,7 @@ def __enter__(self) -> Self:
280282
self.fake_mode.__enter__()
281283
tls.env = self
282284
self.errors = ErrorReporting(self.settings) # clear prior errors
285+
self.loop_dependency_checker = LoopDependencyChecker()
283286
return self
284287

285288
def __exit__(

helion/_compiler/device_function.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636

3737
if TYPE_CHECKING:
3838
from ..runtime.config import Config
39+
from .program_id import ProgramIDs
40+
from .program_id import SharedProgramID
3941

4042
_P = TypeVar("_P", bound="TensorPropertyArg")
4143

@@ -145,7 +147,7 @@ def __init__(self, name: str, config: Config) -> None:
145147
self._unique_counter: dict[str, itertools.count[int]] = defaultdict(
146148
itertools.count
147149
)
148-
self.grid_expr: ast.AST | None = None
150+
self.pid: SharedProgramID | ProgramIDs | None = None
149151
self.namespace: _Namespace = _Namespace()
150152
self.namespace._used_names.update(reserved_names())
151153
self._variable_renames: dict[str, list[str]] = {}
@@ -169,9 +171,9 @@ def merge_variable_names(self, a: str, b: str) -> None:
169171
for n in name_group:
170172
self._variable_renames[n] = name_group
171173

172-
def set_grid_expr(self, grid_expr: ast.AST) -> None:
173-
assert self.grid_expr is None, "grid_expr already set"
174-
self.grid_expr = grid_expr
174+
def set_pid(self, pid: SharedProgramID | ProgramIDs) -> None:
175+
assert self.pid is None, "pid already set"
176+
self.pid = pid
175177

176178
def sympy_expr(self, expr: sympy.Expr) -> str:
177179
expr_to_origin = HostFunction.current().expr_to_origin
@@ -341,12 +343,12 @@ def codegen_function_call(self) -> ast.AST:
341343
f"num_stages={self.config.num_stages}",
342344
]
343345
)
344-
grid_expr = self.grid_expr
345-
assert grid_expr is not None
346+
pid = self.pid
347+
assert pid is not None
346348
# TODO(jansel): we should run CSE on this statement
347349
call_statement = statement_from_string(
348350
f"{self.name}[__call_grid_expr]({', '.join(args)})",
349-
__call_grid_expr=grid_expr,
351+
__call_grid_expr=pid.codegen_grid(),
350352
)
351353
assert isinstance(call_statement, ExtendedAST)
352354
# Mark the kernel call so we can find it in codegen_precompile_def

helion/_compiler/device_ir.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -249,22 +249,22 @@ class DeviceIR:
249249
def __init__(self) -> None:
250250
super().__init__()
251251
self.graphs: list[GraphInfo] = []
252-
self.root_id: int | None = None
252+
self.root_ids: list[int] = []
253253
self.rolled_reductions: list[RolledReductionInfo] = []
254254
self.grid_block_ids: list[list[int]] = []
255255

256-
def get_root(self, config: Config) -> torch.fx.Graph:
256+
def get_root(self, config: Config, graph_id: int) -> torch.fx.Graph:
257257
"""If we are using a rolled reduction, return the rolled reduction graph; otherwise
258258
return the root graph."""
259-
if (root_id := self.root_id) is None:
260-
raise AssertionError("No root graph")
259+
if graph_id >= len(self.graphs):
260+
raise AssertionError("Invalid graph id")
261261
reduction_loops = config.reduction_loops
262262
if len(reduction_loops) > 1:
263263
raise NotImplementedError("Multiple reduction loops not implemented")
264264
if len(reduction_loops) == 0 or reduction_loops[0] is None:
265-
return self.graphs[root_id].graph
265+
return self.graphs[graph_id].graph
266266
for info in reversed(self.rolled_reductions):
267-
if info.original_graph_id == root_id:
267+
if info.original_graph_id == graph_id:
268268
assert info.new_graph_id is not None
269269
return self.graphs[info.new_graph_id].graph
270270
raise AssertionError("No rolled reduction graph found")
@@ -301,8 +301,7 @@ def add_reduction_loop_graph(
301301
)
302302

303303
def add_root_graph(self, graph: torch.fx.Graph) -> None:
304-
assert self.root_id is None
305-
self.root_id = self.add_graph(graph, graph_info_cls=RootGraphInfo)
304+
self.root_ids.append(self.add_graph(graph, graph_info_cls=RootGraphInfo))
306305

307306
def build_rolled_reductions(self) -> None:
308307
env = CompileEnvironment.current()

helion/_compiler/generate_ast.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .device_function import DeviceFunction
2020
from .inductor_lowering import CodegenState
2121
from .inductor_lowering import codegen_call_with_graph
22+
from .program_id import SharedProgramID
2223
from .variable_origin import ArgumentOrigin
2324

2425
if TYPE_CHECKING:
@@ -42,6 +43,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
4243
self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
4344
collections.defaultdict(list)
4445
)
46+
self.next_else_block: list[ast.AST] | None = None
4547

4648
def offset_var(self, block_idx: int) -> str:
4749
return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx)
@@ -54,7 +56,9 @@ def mask_var(self, block_idx: int) -> str | None:
5456
return loops[-1].strategy.mask_var(block_idx)
5557
return None
5658

57-
def add_statement(self, stmt: ast.AST | str) -> None:
59+
def add_statement(self, stmt: ast.AST | str | None) -> None:
60+
if stmt is None:
61+
return
5862
if isinstance(stmt, str):
5963
stmt = statement_from_string(stmt)
6064
self.statements_stack[-1].append(stmt)
@@ -134,13 +138,36 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
134138
fields[field] = old_value
135139
return node.new(fields)
136140

137-
def visit_For(self, node: ast.For) -> ast.AST:
141+
def visit_For(self, node: ast.For) -> ast.AST | None:
138142
assert isinstance(node, ExtendedAST)
139143
if node._loop_type == LoopType.GRID:
140144
assert not node.orelse
145+
146+
if len(self.host_function.device_ir.root_ids) == 1:
147+
body = self.device_function.body
148+
else:
149+
assert len(self.host_function.device_ir.root_ids) > 1
150+
assert node._root_id is not None
151+
# Multiple top level for loops
152+
153+
if node._root_id == 0:
154+
self.device_function.set_pid(
155+
SharedProgramID(
156+
self.device_function.new_var("pid_shared", dce=False)
157+
)
158+
)
159+
self.device_function.body.append(
160+
self.device_function.pid.codegen_pid_init()
161+
)
162+
if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
163+
body = []
164+
else:
165+
# This is the last top-level for loop; don't emit more if statements
166+
assert self.next_else_block is not None
167+
body = self.next_else_block
141168
with (
142169
self.set_on_device(),
143-
self.set_statements(self.device_function.body),
170+
self.set_statements(body),
144171
):
145172
iter_node = node.iter
146173
assert isinstance(iter_node, ExtendedAST)
@@ -166,21 +193,44 @@ def visit_For(self, node: ast.For) -> ast.AST:
166193

167194
from .inductor_lowering import CodegenState
168195

169-
fn._codegen(
170-
CodegenState(
171-
self,
172-
fx_node=None,
173-
proxy_args=[*bound.arguments.values()],
174-
ast_args=None,
175-
),
196+
state = CodegenState(
197+
self,
198+
fx_node=None,
199+
proxy_args=[*bound.arguments.values()],
200+
ast_args=None,
176201
)
202+
203+
fn._codegen(state)
204+
assert node._root_id is not None
177205
codegen_call_with_graph(
178206
self,
179-
self.host_function.device_ir.get_root(self.device_function.config),
207+
self.host_function.device_ir.get_root(
208+
self.device_function.config,
209+
self.host_function.device_ir.root_ids[node._root_id],
210+
),
180211
[],
181212
)
213+
# If we are in a multi top level loop, for all loops except for the last one
214+
# emit if/else blocks
215+
if node._root_id < len(self.host_function.device_ir.root_ids) - 1:
216+
block = (
217+
self.device_function.body
218+
if self.next_else_block is None
219+
else self.next_else_block
220+
)
221+
self.next_else_block = []
222+
block.append(
223+
create(
224+
ast.If,
225+
test=self.device_function.pid.codegen_test(state),
226+
body=body,
227+
orelse=self.next_else_block,
228+
)
229+
)
182230
self.device_function.dead_code_elimination()
183-
return self.device_function.codegen_function_call()
231+
if node._root_id == len(self.host_function.device_ir.root_ids) - 1:
232+
return self.device_function.codegen_function_call()
233+
return None
184234
return self.generic_visit(node)
185235

186236
def visit_Name(self, node: ast.Name) -> ast.AST:
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
import itertools
4+
from typing import TYPE_CHECKING
5+
6+
from .. import exc
7+
from .ast_read_writes import ReadWrites
8+
9+
if TYPE_CHECKING:
10+
import ast
11+
12+
13+
class LoopDependencyChecker:
14+
"""
15+
A class to check dependencies between top-level for loops in a Helion kernel.
16+
17+
This class tracks memory accesses (reads and writes) for each top-level for loop
18+
and raises an error if a later loop reads or writes to anything written in a
19+
previous loop.
20+
"""
21+
22+
def __init__(self) -> None:
23+
self.reads: set[str] = set()
24+
self.writes: set[str] = set()
25+
26+
def register_loop(self, loop_node: ast.For) -> None:
27+
rw = ReadWrites.from_list(loop_node.body)
28+
29+
self._check_dependencies(rw)
30+
31+
self.reads |= set(rw.reads)
32+
self.writes |= set(rw.writes)
33+
34+
def _check_dependencies(self, rw: ReadWrites) -> None:
35+
"""
36+
Check for dependencies between the current loop and previous loops.
37+
38+
Raises:
39+
exc.LoopDependencyError: If a dependency is detected
40+
"""
41+
for name in sorted(itertools.chain(rw.reads, rw.writes)):
42+
if name in self.writes:
43+
raise exc.LoopDependencyError(name)

0 commit comments

Comments
 (0)