
Commit fc21a87

Add support for multiple top level for loops
ghstack-source-id: 056d55f
Pull Request resolved: #52
1 parent 59c707c commit fc21a87
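
For context, a minimal sketch of the kind of kernel this commit targets: two top-level grid loops in one Helion kernel. The kernel body, the hl.tile usage, and all names are illustrative assumptions rather than code from this PR; before this change a second top-level grid loop raised MultipleDeviceLoops (see helion/exc.py below).

    import torch
    import helion
    import helion.language as hl

    @helion.kernel()
    def two_top_level_loops(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out_x = torch.empty_like(x)
        out_y = torch.empty_like(y)
        for tile in hl.tile(x.size(0)):  # first top-level grid loop -> _root_id 0
            out_x[tile] = x[tile] + 1
        for tile in hl.tile(y.size(0)):  # second top-level grid loop -> _root_id 1
            out_y[tile] = y[tile] * 2
        return out_x, out_y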

File tree

9 files changed: +379 −29 lines changed

helion/_compiler/ast_extension.py

Lines changed: 3 additions & 0 deletions
@@ -49,13 +49,15 @@ def __init__(
         _type_info: TypeInfo | None = None,
         _loop_type: LoopType = LoopType.UNSET,
         _is_kernel_call: bool = False,
+        _root_id: int | None = None,
         **kwargs: object,
     ) -> None:
         super().__init__(**kwargs)
         self._type_info: TypeInfo | None = _type_info
         self._location: SourceLocation = _location
         self._loop_type: LoopType = _loop_type
         self._is_kernel_call: bool = _is_kernel_call
+        self._root_id: int | None = _root_id
 
     def new(self, fields: dict[str, object]) -> ExtendedAST:
         result = self.__class__(
@@ -64,6 +66,7 @@ def new(self, fields: dict[str, object]) -> ExtendedAST:
             _type_info=self._type_info,
             _loop_type=self._loop_type,
             _is_kernel_call=self._is_kernel_call,
+            _root_id=self._root_id,
         )
         return self._location.to_ast(result)

helion/_compiler/device_function.py

Lines changed: 6 additions & 1 deletion
@@ -36,6 +36,7 @@
 
 if TYPE_CHECKING:
     from ..runtime.config import Config
+    from .program_id import SharedProgramIDs
 
 _P = TypeVar("_P", bound="TensorPropertyArg")
 
@@ -158,6 +159,8 @@ def __init__(self, name: str, config: Config) -> None:
         self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config)
         self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)
 
+        self.shared_pid: SharedProgramIDs | None = None
+
     def block_size_var(self, block_size_idx: int) -> str | None:
         return self.block_size_var_cache.get((block_size_idx,))
 
@@ -170,7 +173,9 @@ def merge_variable_names(self, a: str, b: str) -> None:
             self._variable_renames[n] = name_group
 
     def set_grid_expr(self, grid_expr: ast.AST) -> None:
-        assert self.grid_expr is None, "grid_expr already set"
+        if not self.shared_pid:
+            # For shared pid, its OK to set grid expr multiple times, just use the last one
+            assert self.grid_expr is None, "grid_expr already set"
         self.grid_expr = grid_expr
 
     def sympy_expr(self, expr: sympy.Expr) -> str:

helion/_compiler/device_ir.py

Lines changed: 8 additions & 8 deletions
@@ -211,22 +211,23 @@ class DeviceIR:
     def __init__(self) -> None:
         super().__init__()
         self.graphs: list[GraphInfo] = []
-        self.root_id: int | None = None
+        self.root_ids: list[int] = []
         self.rolled_reductions: list[RolledReductionInfo] = []
         self.grid_block_indices: list[list[int]] = []
 
-    def get_root(self, config: Config) -> torch.fx.GraphModule:
+    def get_root(self, config: Config, root_id: int) -> torch.fx.GraphModule:
         """ " If we are using a rolled reduction, return the rolled reduction graph otherwise
         return the root graph."""
-        if (root_id := self.root_id) is None:
-            raise AssertionError("No root graph")
+        if root_id >= len(self.root_ids):
+            raise AssertionError("Invalid root graph")
+        rid = self.root_ids[root_id]
         reduction_loops = config.reduction_loops
         if len(reduction_loops) > 1:
             raise NotImplementedError("Multiple reduction loops not implemented")
         if len(reduction_loops) == 0 or reduction_loops[0] is None:
-            return self.graphs[root_id].graph
+            return self.graphs[rid].graph
         for info in reversed(self.rolled_reductions):
-            if info.original_graph_id == root_id:
+            if info.original_graph_id == rid:
                 assert info.new_graph_id is not None
                 return self.graphs[info.new_graph_id].graph
         raise AssertionError("No rolled reduction graph found")
@@ -259,8 +260,7 @@ def add_reduction_loop_graph(
         )
 
     def add_root_graph(self, graph: torch.fx.GraphModule) -> None:
-        assert self.root_id is None
-        self.root_id = self.add_graph(graph, graph_info_cls=RootGraphInfo)
+        self.root_ids.append(self.add_graph(graph, graph_info_cls=RootGraphInfo))
 
     def build_rolled_reductions(self) -> None:
         env = CompileEnvironment.current()
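
The indirection between a loop's positional root_id and the stored graph index is easy to misread; the following self-contained toy (a stand-in, not the real DeviceIR class) shows the bookkeeping the hunk above introduces:

    # Toy model of the new root-graph bookkeeping (illustrative only).
    class ToyDeviceIR:
        def __init__(self) -> None:
            self.graphs: list[str] = []
            self.root_ids: list[int] = []

        def add_graph(self, graph: str) -> int:
            self.graphs.append(graph)
            return len(self.graphs) - 1

        def add_root_graph(self, graph: str) -> None:
            # Each top-level grid loop appends its graph id, in source order.
            self.root_ids.append(self.add_graph(graph))

        def get_root(self, root_id: int) -> str:
            # root_id is the loop's position among top-level loops, not a graph id.
            if root_id >= len(self.root_ids):
                raise AssertionError("Invalid root graph")
            return self.graphs[self.root_ids[root_id]]

    ir = ToyDeviceIR()
    ir.add_graph("helper graph")     # non-root graphs may be interleaved
    ir.add_root_graph("loop 0 body")  # becomes root_id 0
    ir.add_root_graph("loop 1 body")  # becomes root_id 1
    assert ir.get_root(0) == "loop 0 body"
    assert ir.get_root(1) == "loop 1 body"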

helion/_compiler/generate_ast.py

Lines changed: 59 additions & 12 deletions
@@ -18,6 +18,7 @@
 from .compile_environment import CompileEnvironment
 from .device_function import DeviceFunction
 from .inductor_lowering import codegen_call_with_graph
+from .program_id import SharedProgramIDs
 from .variable_origin import ArgumentOrigin
 
 if TYPE_CHECKING:
@@ -41,6 +42,7 @@ def __init__(self, func: HostFunction, config: Config) -> None:
         self.active_device_loops: dict[int, list[DeviceLoopOrGridState]] = (
             collections.defaultdict(list)
         )
+        self.next_else_block: list[ast.AST] | None = None
 
     def offset_var(self, block_idx: int) -> str:
         return self.active_device_loops[block_idx][-1].strategy.offset_var(block_idx)
@@ -51,7 +53,9 @@ def index_var(self, block_idx: int) -> str:
     def mask_var(self, block_idx: int) -> str | None:
         return self.active_device_loops[block_idx][-1].strategy.mask_var(block_idx)
 
-    def add_statement(self, stmt: ast.AST | str) -> None:
+    def add_statement(self, stmt: ast.AST | str | None) -> None:
+        if stmt is None:
+            return
         if isinstance(stmt, str):
             stmt = statement_from_string(stmt)
         self.statements_stack[-1].append(stmt)
@@ -131,13 +135,34 @@ def generic_visit(self, node: ast.AST) -> ast.AST:
                 fields[field] = old_value
         return node.new(fields)
 
-    def visit_For(self, node: ast.For) -> ast.AST:
+    def visit_For(self, node: ast.For) -> ast.AST | None:
         assert isinstance(node, ExtendedAST)
         if node._loop_type == LoopType.GRID:
             assert not node.orelse
+
+            if len(self.host_fn.device_ir.root_ids) == 1:
+                body = self.device_function.body
+            else:
+                assert len(self.host_fn.device_ir.root_ids) > 1
+                assert node._root_id is not None
+                # Multiple top level for loops
+
+                if node._root_id == 0:
+                    self.device_function.shared_pid = SharedProgramIDs(
+                        self.device_function.new_var("pid_shared", dce=False)
+                    )
+                    self.device_function.body.append(
+                        self.device_function.shared_pid.codegen_pid_init()
+                    )
+                if node._root_id < len(self.host_fn.device_ir.root_ids) - 1:
+                    body = []
+                else:
+                    # This is the last top level for, dont emit more if statements
+                    assert self.next_else_block is not None
+                    body = self.next_else_block
             with (
                 self.set_on_device(),
-                self.set_statements(self.device_function.body),
+                self.set_statements(body),
             ):
                 iter_node = node.iter
                 assert isinstance(iter_node, ExtendedAST)
@@ -163,21 +188,43 @@ def visit_For(self, node: ast.For) -> ast.AST:
 
                 from .inductor_lowering import CodegenState
 
-                fn._codegen(
-                    CodegenState(
-                        self,
-                        fx_node=None,
-                        proxy_args=[*bound.arguments.values()],
-                        ast_args=None,
-                    ),
+                state = CodegenState(
+                    self,
+                    fx_node=None,
+                    proxy_args=[*bound.arguments.values()],
+                    ast_args=None,
                 )
+
+                fn._codegen(state)
+                assert node._root_id is not None
                 codegen_call_with_graph(
                     self,
-                    self.host_fn.device_ir.get_root(self.device_function.config),
+                    self.host_fn.device_ir.get_root(
+                        self.device_function.config, node._root_id
+                    ),
                     [],
                 )
+                # If we are in a multi top level loop, for all loops except for the last one
+                # emit ifthenelse blocks
+                if node._root_id < len(self.host_fn.device_ir.root_ids) - 1:
+                    block = (
+                        self.device_function.body
+                        if self.next_else_block is None
+                        else self.next_else_block
+                    )
+                    self.next_else_block = []
+                    block.append(
+                        create(
+                            ast.If,
+                            test=self.device_function.shared_pid.codegen_test(state),
+                            body=body,
+                            orelse=self.next_else_block,
+                        )
+                    )
             self.device_function.dead_code_elimination()
-            return self.device_function.codegen_function_call()
+            if node._root_id == len(self.host_fn.device_ir.root_ids) - 1:
+                return self.device_function.codegen_function_call()
+            return None
         return self.generic_visit(node)
 
     def visit_Name(self, node: ast.Name) -> ast.AST:
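
The if/else chaining above is the heart of the change: every root except the last gets an ast.If whose orelse list becomes the insertion point for the next root, and the last root writes straight into that pending else. A minimal plain-Python sketch of that threading (not the real visitor; names and the dict-based "If" nodes are illustrative):

    def chain_roots(root_bodies: list[list[str]], tests: list[str]) -> list[object]:
        device_body: list[object] = []   # stands in for device_function.body
        next_else_block: list[object] | None = None
        for root_id, body in enumerate(root_bodies):
            if root_id == len(root_bodies) - 1:
                # Last root: no test is emitted; its statements go into the pending
                # else block (or straight into the body when there is only one root).
                (device_body if next_else_block is None else next_else_block).extend(body)
                break
            target = device_body if next_else_block is None else next_else_block
            next_else_block = []
            target.append({"test": tests[root_id], "body": body, "orelse": next_else_block})
        return device_body

    print(chain_roots([["loop0"], ["loop1"], ["loop2"]], ["test_root0", "test_root1"]))
    # [{'test': 'test_root0', 'body': ['loop0'],
    #   'orelse': [{'test': 'test_root1', 'body': ['loop1'], 'orelse': ['loop2']}]}]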

helion/_compiler/program_id.py

Lines changed: 45 additions & 0 deletions
@@ -9,6 +9,8 @@
 from helion._compiler.host_function import HostFunction
 
 if TYPE_CHECKING:
+    import ast
+
     import sympy
 
     from helion._compiler.inductor_lowering import CodegenState
@@ -57,6 +59,49 @@ def codegen(self, state: CodegenState) -> None:
         state.device_function.set_grid_expr(expr_from_string(f"({', '.join(grid)},)"))
 
 
+class SharedProgramIDs(ProgramIDs):
+    """
+    Use the same PID for all blocks
+    TODO(oulgen): Currently only supports 1 dimension
+    """
+
+    def __init__(self, shared_pid_var: str) -> None:
+        super().__init__()
+        self.shared_pid_var = shared_pid_var
+
+    def codegen_pid_init(
+        self,
+    ) -> ast.stmt:
+        return statement_from_string(f"{self.shared_pid_var} = tl.program_id(0)")
+
+    def codegen_test(self, state: CodegenState) -> ast.AST:
+        blocks = []
+        for pid in self.pids:
+            blocks.append(pid.device_cdiv(state))
+
+        assert len(blocks) > 0
+        return expr_from_string(f"{self.shared_pid_var} < ({'+ '.join(blocks)})")
+
+    def codegen(self, state: CodegenState) -> None:
+        # TODO(oulgen): We need CSE between codegen_test and codegen for shared device cdivs
+        blocks = []
+        for pid in self.pids[:-1]:
+            blocks.append(pid.device_cdiv(state))
+
+        if blocks:
+            state.codegen.statements_stack[-1].insert(
+                0,
+                statement_from_string(
+                    f"{self.shared_pid_var} -= ({'+ '.join(blocks)})"
+                ),
+            )
+
+        grid = []
+        for pid in self.pids:
+            grid.append(pid.host_cdiv())
+        state.device_function.set_grid_expr(expr_from_string(f"({'+ '.join(grid)},)"))
+
+
 class VirtualProgramIDs(ProgramIDs):
     """Only use the x grid and compute other dimensions"""
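
To make SharedProgramIDs concrete, here is a hand-written Triton sketch of the kernel shape these pieces aim to produce for two simple top-level loops. This is an assumed illustration, not actual Helion output: the elementwise bodies, argument names, and block sizes are made up. It mirrors the class above: one shared pid is initialized (codegen_pid_init), the first loop is guarded by a range test over its cdiv (codegen_test), the remaining programs are rebased by subtracting the earlier cdivs (codegen), and the host grid is the sum of per-loop cdivs (set_grid_expr).

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _two_loop_kernel(x_ptr, y_ptr, out_x_ptr, out_y_ptr, x_len, y_len,
                         BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr):
        pid_shared = tl.program_id(0)
        if pid_shared < tl.cdiv(x_len, BLOCK_X):
            # Programs in the first range execute the first top-level loop.
            offs = pid_shared * BLOCK_X + tl.arange(0, BLOCK_X)
            mask = offs < x_len
            tl.store(out_x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)
        else:
            # Remaining programs are rebased and execute the second loop.
            pid_shared -= tl.cdiv(x_len, BLOCK_X)
            offs = pid_shared * BLOCK_Y + tl.arange(0, BLOCK_Y)
            mask = offs < y_len
            tl.store(out_y_ptr + offs, tl.load(y_ptr + offs, mask=mask) * 2, mask=mask)

    def two_loop(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out_x, out_y = torch.empty_like(x), torch.empty_like(y)
        BLOCK_X = BLOCK_Y = 128
        # Host grid is the sum of the per-loop cdivs, matching the '+'.join in codegen().
        grid = (triton.cdiv(x.numel(), BLOCK_X) + triton.cdiv(y.numel(), BLOCK_Y),)
        _two_loop_kernel[grid](x, y, out_x, out_y, x.numel(), y.numel(),
                               BLOCK_X=BLOCK_X, BLOCK_Y=BLOCK_Y)
        return out_x, out_y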

helion/_compiler/tile_strategy.py

Lines changed: 13 additions & 3 deletions
@@ -24,8 +24,10 @@
 from .program_id import L2GroupingProgramIDs
 from .program_id import ProgramID
 from .program_id import ProgramIDs
+from .program_id import SharedProgramIDs
 from .program_id import VirtualProgramIDs
 from .variable_origin import BlockSizeOrigin
+from helion import exc
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -379,14 +381,20 @@ def codegen_grid(self, state: CodegenState) -> None:
         dtype = env.triton_index_type()
         block_sizes = self.block_size
         assert len(block_sizes) == len(block_indices)
-        pids = self.select_pid_strategy()
+        pids = self.select_pid_strategy(state)
+        if isinstance(pids, SharedProgramIDs) and len(block_sizes) > 1:
+            # TODO(oulgen): Support this
+            raise exc.MultipleDeviceLoopBlocks
         for i, (block_idx, block_size) in enumerate(
             reversed(self._reorder([*zip(block_indices, block_sizes, strict=True)]))
         ):
            numel = env.block_sizes[block_idx].numel
            offset_var = self.offset_var(block_idx)
            index_var = self.index_var(block_idx)
-            pid_var = device_fn.new_var(f"pid_{i}", dce=True)
+            if isinstance(pids, SharedProgramIDs):
+                pid_var = pids.shared_pid_var
+            else:
+                pid_var = device_fn.new_var(f"pid_{i}", dce=True)
             if block_size != 1:
                 block_size_var = self.block_size_var(block_idx)
                 assert block_size_var is not None
@@ -433,7 +441,9 @@ def _setup_mask(
             f"{mask_var} = ({index_var} < ({state.device_function.sympy_expr(numel)}))"
         )
 
-    def select_pid_strategy(self) -> ProgramIDs:
+    def select_pid_strategy(self, state: CodegenState) -> ProgramIDs:
+        if (shared_pid := state.device_function.shared_pid) is not None:
+            return shared_pid
         if self.l2_grouping > 1:
             return L2GroupingProgramIDs(group_size=self.l2_grouping)
         if 1 < len(self.block_indices) <= 3 and self.fn.config.use_yz_grid:

helion/_compiler/type_propagation.py

Lines changed: 2 additions & 3 deletions
@@ -1901,12 +1901,11 @@ def visit_For(self, node: ast.For) -> TypeInfo:
         if node.orelse:
             raise exc.DeviceLoopElseBlock(fn.__qualname__)
 
-        self.device_loop_count += 1
         if self.device_loop_depth == 0:
             self.func.set_local_types(parent_scope.extract_locals())
             node._loop_type = LoopType.GRID
-            if self.device_loop_count != 1:
-                raise exc.MultipleDeviceLoops
+            node._root_id = self.device_loop_count
+            self.device_loop_count += 1
             if len(ExtendedAST.current()) != 1:
                 raise exc.NestedGridLoop
 

helion/exc.py

Lines changed: 2 additions & 2 deletions
@@ -73,8 +73,8 @@ class DeviceLoopElseBlock(BaseError):
     message = "for...else block is not allowed in a {0} device loop."
 
 
-class MultipleDeviceLoops(BaseError):
-    message = "Multiple grid loops are not allowed. Support for this may be added in the future."
+class MultipleDeviceLoopBlocks(BaseError):
+    message = "Multiple blocks for multiple top level grid loops are not yet allowed. Support for this may be added in the future."
 
 
 class NestedGridLoop(BaseError):
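
The renamed error pairs with the guard added in tile_strategy.py: the shared-pid path currently handles only one block dimension per grid loop. A hypothetical kernel that would hit it (hl.tile list form and all names are assumptions, not from this PR's tests) might look like this:

    import torch
    import helion
    import helion.language as hl

    @helion.kernel()
    def two_loops_2d(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out_x, out_y = torch.empty_like(x), torch.empty_like(y)
        for tile_i, tile_j in hl.tile([x.size(0), x.size(1)]):  # two block dims per grid loop
            out_x[tile_i, tile_j] = x[tile_i, tile_j] + 1
        for tile_i, tile_j in hl.tile([y.size(0), y.size(1)]):
            out_y[tile_i, tile_j] = y[tile_i, tile_j] * 2
        return out_x, out_y
    # Expected for now: compilation raises helion.exc.MultipleDeviceLoopBlocks,
    # since multiple top-level loops combined with multi-block grids are not yet supported.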
