Add support for multiple top level for loops #52
Conversation
[ghstack-poisoned]
[ghstack-poisoned]
        x0[tile] += c0
    for tile in hl.tile(x1.size()):
        x1[tile] += c1
    for tile in hl.tile(x2.size()):
What is codegened if one of the other loops happens to have the same iteration bounds as a previous one?
Still the same thing; I didn't do any fusion. That could be a good next step for me as a follow-up.
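For illustration, a minimal sketch of the pattern under discussion; the loop bodies follow the diff above, while the decorator, signature, and constants are assumptions. Two top-level hl.tile loops with identical bounds are still lowered as two separate loops; no fusion is attempted.

import torch
import helion
import helion.language as hl


# Hypothetical kernel with two top-level loops over identical bounds.
# With this PR, each loop still gets its own codegen path; no fusion yet.
@helion.kernel
def add_scalars(x0: torch.Tensor, x1: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    for tile in hl.tile(x0.size()):
        x0[tile] += 1.0
    for tile in hl.tile(x1.size()):  # same iteration bounds as the loop above
        x1[tile] += 2.0
    return x0, x1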
helion/_compiler/device_function.py
Outdated
@@ -170,7 +173,9 @@ def merge_variable_names(self, a: str, b: str) -> None:
             self._variable_renames[n] = name_group

     def set_grid_expr(self, grid_expr: ast.AST) -> None:
         assert self.grid_expr is None, "grid_expr already set"
+        if not self.shared_pid:
Can we refactor the existing grid handling to reuse some of the new, more general code?
What do you mean by "more general code"?
The old code uses self.grid_expr, which only supports one grid. The new code supports a list of grids, and self.grid_expr will be set to the last grid (and will be incorrect if it is ever used with multiple grids active). This seems like duplicate logic; we should delete self.grid_expr and use the new structure.
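A rough sketch of that refactor, purely for illustration (the class shape and method names here are assumptions, not the actual helion code): keep a list of grid expressions instead of a single field, so nothing is silently overwritten when several top-level loops are active.

import ast


class DeviceFunction:
    # Hypothetical sketch: replace a single `grid_expr: ast.AST | None`
    # field with a list so every top-level loop can register its own grid.
    def __init__(self) -> None:
        self.grid_exprs: list[ast.AST] = []

    def add_grid_expr(self, grid_expr: ast.AST) -> None:
        # Each top-level loop appends its grid; nothing is overwritten,
        # so there is no "last grid wins" footgun.
        self.grid_exprs.append(grid_expr)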
helion/_compiler/device_ir.py
Outdated
            raise AssertionError("No root graph")
        if root_id >= len(self.root_ids):
            raise AssertionError("Invalid root graph")
        rid = self.root_ids[root_id]
Is this extra layer of indirection actually needed? Could just use the graph ID directly?
helion/_compiler/tile_strategy.py
Outdated
from .program_id import VirtualProgramIDs
from .variable_origin import BlockSizeOrigin
from helion import exc
Suggested change:
-from helion import exc
+from .. import exc
helion/_compiler/tile_strategy.py
Outdated
# TODO(oulgen): Support this
raise exc.MultipleDeviceLoopBlocks
Maybe we could support this by having the other 1D PID strategies (all but one of them) take shared_pid_var as an arg (which defaults to tl.program_id(0)).
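A sketch of that idea, with class and attribute names made up for illustration (helion's real ProgramIDs classes work on codegen state rather than plain strings): the strategy reads its linear pid from an injectable expression instead of hard-coding tl.program_id(0), so several top-level loops can share one launch grid.

class FlatProgramIDs:
    """Hypothetical 1D PID strategy that can share a program id.

    `shared_pid_var` is the source expression the generated Triton code
    reads its linear pid from; it defaults to "tl.program_id(0)" so the
    single-loop case is unchanged.
    """

    def __init__(self, shared_pid_var: str = "tl.program_id(0)") -> None:
        self.shared_pid_var = shared_pid_var

    def codegen_pid(self, offset_expr: str) -> str:
        # When several top-level loops share one launch grid, the caller
        # passes an offset so each loop claims its own slice of pids.
        return f"({self.shared_pid_var}) - ({offset_expr})"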
I was thinking of doing this in a follow-up PR; would you prefer me to do it here?
helion/_compiler/tile_strategy.py
Outdated
-    def select_pid_strategy(self) -> ProgramIDs:
+    def select_pid_strategy(self, state: CodegenState) -> ProgramIDs:
+        if (shared_pid := state.device_function.shared_pid) is not None:
+            return shared_pid
         if self.l2_grouping > 1:
             return L2GroupingProgramIDs(group_size=self.l2_grouping)
         if 1 < len(self.block_indices) <= 3 and self.fn.config.use_yz_grid:
Are we updating the ConfigSpec so use_yz_grid is removed from the config?
helion/exc.py
Outdated
-class MultipleDeviceLoops(BaseError):
-    message = "Multiple grid loops are not allowed. Support for this may be added in the future."
+class MultipleDeviceLoopBlocks(BaseError):
+    message = "Multiple blocks for multiple top level grid loops are not yet allowed. Support for this may be added in the future."
IMO we should support this. It should be a small change to VirtualProgramIDs to use shared_pid_var.
Ditto on the follow-up PR, but I can do it here if you prefer.
For both this and the above: if it is easy, let's do it here. If it takes more work, a follow-up is fine.
helion/_compiler/tile_strategy.py
Outdated
-    def select_pid_strategy(self) -> ProgramIDs:
+    def select_pid_strategy(self, state: CodegenState) -> ProgramIDs:
+        if (shared_pid := state.device_function.shared_pid) is not None:
+            return shared_pid
         if self.l2_grouping > 1:
             return L2GroupingProgramIDs(group_size=self.l2_grouping)
Does this work with the l2 grouping stuff? (It should be possible since L2GroupingProgramIDs only uses a 1D pid.)
One other thought: we should introduce a check for data dependencies between the loops. If loop 1 writes to X, loop 2 cannot read from X, and we should error in that case.
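A minimal sketch of such a check; the ReadWrites container and error name mirror identifiers that show up later in this diff, but the surrounding structure is an assumption for illustration:

import itertools
from dataclasses import dataclass, field


class LoopDependencyError(Exception):
    pass


@dataclass
class ReadWrites:
    # Names of tensors read/written by one top-level loop.
    reads: set[str] = field(default_factory=set)
    writes: set[str] = field(default_factory=set)


def check_loop_dependencies(loops: list[ReadWrites]) -> None:
    # Tensors written by any earlier top-level loop.
    written_earlier: set[str] = set()
    for rw in loops:
        # sorted() keeps error ordering deterministic across runs.
        for name in sorted(itertools.chain(rw.reads, rw.writes)):
            if name in written_earlier:
                raise LoopDependencyError(
                    f"loop touches {name!r}, which an earlier loop wrote"
                )
        written_earlier |= rw.writes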
Sounds good. I'll come back and fix this PR once I finish the fbcode demo.
[ghstack-poisoned]
@jansel this is ready for review now. Since the last time you looked, I have
@@ -145,7 +147,7 @@ def __init__(self, name: str, config: Config) -> None:
         self._unique_counter: dict[str, itertools.count[int]] = defaultdict(
             itertools.count
         )
-        self.grid_expr: ast.AST | None = None
+        self.pid: SharedProgramID | ProgramIDs | None = None
Nit: could introduce a base class here.
Hmm, not sure how this would work. In truth, this is the Either type from Haskell etc.: pid is either a SharedProgramID that contains a bunch of pids for multiple for loops, or a single pid for a single for loop. A base class could contain codegen_grid, but I also use this as device_function.pid to check/access the SharedProgramID in the compiler.
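A sketch of the shape being described, with internals assumed for illustration (only the pid annotation follows the diff above):

from __future__ import annotations


class ProgramIDs:
    """Strategy for a single top-level loop."""


class SharedProgramID:
    """Wraps one ProgramIDs per top-level loop behind a shared pid."""

    def __init__(self) -> None:
        self.pids: list[ProgramIDs] = []


class DeviceFunction:
    def __init__(self) -> None:
        # Either-style field: a single strategy for one top-level loop,
        # or a SharedProgramID holding a strategy per top-level loop.
        self.pid: SharedProgramID | ProgramIDs | None = None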
Feel free to skip this change then.
    Raises:
        exc.LoopDependencyError: If a dependency is detected
    """
    for name in itertools.chain(rw.reads, rw.writes):
Sets are non-deterministic.
Suggested change:
-for name in itertools.chain(rw.reads, rw.writes):
+for name in sorted(itertools.chain(rw.reads, rw.writes)):
[ghstack-poisoned]
@oulgen I suspect it might be causing an error on
#148 should fix it.
Stack from ghstack (oldest at bottom):