Add support for multiple top level for loops #52


Merged: 4 commits, Jun 8, 2025

Conversation

@oulgen oulgen commented May 17, 2025

oulgen added a commit that referenced this pull request May 17, 2025
ghstack-source-id: f1837ad
Pull Request resolved: #52
@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) May 17, 2025
oulgen added a commit that referenced this pull request May 18, 2025
ghstack-source-id: 056d55f
Pull Request resolved: #52
@oulgen marked this pull request as ready for review May 18, 2025 02:43
@oulgen requested review from jansel, yf225 and drisspg May 18, 2025 02:44
for tile in hl.tile(x0.size()):
    x0[tile] += c0
for tile in hl.tile(x1.size()):
    x1[tile] += c1
for tile in hl.tile(x2.size()):
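The pattern this PR enables, several independent top-level tile loops in one kernel, can be mimicked in plain Python. This is a hedged sketch: the `tile` helper below is a hypothetical stand-in for `hl.tile`, and plain lists stand in for device tensors.

```python
def tile(size, block=4):
    """Hypothetical stand-in for hl.tile: yield slices covering [0, size)."""
    for start in range(0, size, block):
        yield slice(start, min(start + block, size))

x0, x1 = [0.0] * 10, [0.0] * 6
c0, c1 = 1.0, 2.0

# Two independent top-level loops, as in the kernel under review.
for t in tile(len(x0)):
    for i in range(*t.indices(len(x0))):
        x0[i] += c0

for t in tile(len(x1)):
    for i in range(*t.indices(len(x1))):
        x1[i] += c1
```

Each loop runs to completion before the next starts, which is why the codegen can give every loop its own grid (or a slice of one shared grid).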
Contributor:
What is codegened if one of the other loops happens to have the same iteration bounds as a previous one?

@oulgen (Contributor Author), May 18, 2025:

Still the same thing; I didn't do any fusion. That could be a good next step for me to do as a follow-up.

@@ -170,7 +173,9 @@ def merge_variable_names(self, a: str, b: str) -> None:
self._variable_renames[n] = name_group

def set_grid_expr(self, grid_expr: ast.AST) -> None:
assert self.grid_expr is None, "grid_expr already set"
if not self.shared_pid:
Contributor:

Can we refactor the existing grid handling to reuse some of the new more general code?

Contributor Author:

What do you mean by "more general code"?

Contributor:

The old stuff uses self.grid_expr which only supports one grid.

The new stuff supports a list of grids, and self.grid_expr will be set to the last grid (and incorrect if ever used with multiple grids active).

This seems like duplicate logic, where we should delete self.grid_expr and use the new structure.
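The list-of-grids structure being argued for here can be illustrated with a small sketch (the function name and scheme are hypothetical, not Helion's code): launch one flat grid sized as the sum of the per-loop grids, then decompose the flat program ID back into a loop index and a local PID.

```python
import itertools

def decompose_pid(pid, grid_sizes):
    """Map a flat program ID to (loop_index, local_pid) over several grids.

    Hypothetical shared-PID scheme: the combined launch grid has
    sum(grid_sizes) programs, and each loop owns a contiguous range.
    """
    # Prefix sums give each loop's [lo, hi) range of flat PIDs.
    offsets = [0, *itertools.accumulate(grid_sizes)]
    for loop_idx, (lo, hi) in enumerate(zip(offsets, offsets[1:])):
        if lo <= pid < hi:
            return loop_idx, pid - lo
    raise ValueError(f"pid {pid} out of range {offsets[-1]}")
```

Under this scheme a single grid is just the one-element case, which is the sense in which the old single `self.grid_expr` duplicates the new structure.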

raise AssertionError("No root graph")
if root_id >= len(self.root_ids):
raise AssertionError("Invalid root graph")
rid = self.root_ids[root_id]
Contributor:

Is this extra layer of indirection actually needed? Could just use the graph ID directly?

from .program_id import VirtualProgramIDs
from .variable_origin import BlockSizeOrigin
from helion import exc
Contributor:
Suggested change
from helion import exc
from .. import exc

Comment on lines 386 to 387
# TODO(oulgen): Support this
raise exc.MultipleDeviceLoopBlocks
Contributor:

Maybe we could support this by having the other 1D PID strategies (all but one of them) take shared_pid_var as an arg (which defaults to tl.program_id(0)).

Contributor Author:

I was thinking of doing this in a follow-up PR; would you prefer I do it here?

def select_pid_strategy(self) -> ProgramIDs:
def select_pid_strategy(self, state: CodegenState) -> ProgramIDs:
if (shared_pid := state.device_function.shared_pid) is not None:
return shared_pid
if self.l2_grouping > 1:
return L2GroupingProgramIDs(group_size=self.l2_grouping)
if 1 < len(self.block_indices) <= 3 and self.fn.config.use_yz_grid:
@jansel (Contributor), May 18, 2025:

Are we updating the ConfigSpec so use_yz_grid is removed from the config?

helion/exc.py Outdated
class MultipleDeviceLoops(BaseError):
message = "Multiple grid loops are not allowed. Support for this may be added in the future."
class MultipleDeviceLoopBlocks(BaseError):
message = "Multiple blocks for multiple top level grid loops are not yet allowed. Support for this may be added in the future."
Contributor:

IMO we should support this. It should be a small change to VirtualProgramIDs to use shared_pid_var.

Contributor Author:

Ditto on the follow-up PR, but I can do it here if you prefer.

Contributor:

For both this and the above, if it is easy let's do it here. If it takes more work a followup is fine.

def select_pid_strategy(self) -> ProgramIDs:
def select_pid_strategy(self, state: CodegenState) -> ProgramIDs:
if (shared_pid := state.device_function.shared_pid) is not None:
return shared_pid
if self.l2_grouping > 1:
return L2GroupingProgramIDs(group_size=self.l2_grouping)
Contributor:

Does this work with the l2 grouping stuff? (It should be possible since L2GroupingProgramIDs only uses a 1D pid.)

@jansel (Contributor) commented May 24, 2025

One other thought: we should introduce a check for data dependency between the loops. If loop 1 writes to X, loop 2 cannot read from X -- and we should error in that case.

@oulgen (Contributor Author) commented May 24, 2025

One other thought: we should introduce a check for data dependency between the loops. If loop 1 writes to X, loop 2 cannot read from X -- and we should error in that case.

Sounds good. I'll come back and fix this PR once I finish the fbcode demo.
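The dependency check proposed above could look roughly like the following. This is a minimal sketch of the rule "if loop 1 writes to X, loop 2 cannot read (or write) X", operating on read/write name sets; the function name and exception here mirror the discussion, not Helion's exact implementation.

```python
class LoopDependencyError(Exception):
    """Raised when a later top-level loop depends on an earlier one."""

def check_loop_dependencies(loops):
    """Error if a later loop reads or writes a name an earlier loop wrote.

    `loops` is an ordered list of (reads, writes) pairs of variable-name
    sets, one pair per top-level loop.
    """
    written = set()
    for reads, writes in loops:
        conflicts = (reads | writes) & written
        if conflicts:
            raise LoopDependencyError(
                f"later loop touches values written earlier: {sorted(conflicts)}"
            )
        written |= writes
```

For example, `[({'a'}, {'x'}), ({'x'}, set())]` should raise, since the second loop reads `x` after the first loop wrote it.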

oulgen added a commit that referenced this pull request Jun 8, 2025
ghstack-source-id: 1e31557
Pull Request resolved: #52
@oulgen requested a review from jansel June 8, 2025 06:43
@oulgen (Contributor Author) commented Jun 8, 2025

@jansel this is ready for review now. Since the last time you looked, I have:

  • Added a dependency checker to make sure a later loop does not read/write values modified by an earlier loop
  • Added a check that there are no statements between two top-level loops
  • Completely revamped the grid_expr system
  • Made the shared PID a separate concept from the PIDs of individual loops

@@ -145,7 +147,7 @@ def __init__(self, name: str, config: Config) -> None:
self._unique_counter: dict[str, itertools.count[int]] = defaultdict(
itertools.count
)
self.grid_expr: ast.AST | None = None
self.pid: SharedProgramID | ProgramIDs | None = None
Contributor:

Nit: could introduce a base class here.

Contributor Author:

Hmm, not sure how this would work. In truth, this is the Either type from Haskell etc.: pid is either a SharedProgramID that contains a bunch of PIDs for multiple for loops, or a single ProgramIDs for a single for loop.

A base class could contain codegen_grid, but I also use this as device_function.pid to check/access the SharedProgramID in the compiler.
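The Either-style union described here can be sketched with plain Python classes. SharedProgramID, ProgramIDs, and codegen_grid are the real names from the PR, but the bodies below are hypothetical stand-ins used only to show the isinstance-based dispatch.

```python
from dataclasses import dataclass, field

@dataclass
class ProgramIDs:
    """Stand-in: PID handling for a single top-level loop."""
    grid_size: int = 1

@dataclass
class SharedProgramID:
    """Stand-in: wraps the per-loop PIDs of multiple top-level loops."""
    pids: list = field(default_factory=list)

def codegen_grid(pid):
    """Dispatch on the union, the way device_function.pid is checked."""
    if isinstance(pid, SharedProgramID):
        # Combined flat grid spanning all loops.
        return sum(p.grid_size for p in pid.pids)
    return pid.grid_size
```

A base class would only help for the shared codegen_grid path; the compiler-side isinstance check on device_function.pid is exactly what the union expresses.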

Contributor:

Feel free to skip this change then.

Raises:
exc.LoopDependencyError: If a dependency is detected
"""
for name in itertools.chain(rw.reads, rw.writes):
Contributor:

Sets are non-deterministic; sorting makes the iteration order stable.

Suggested change
for name in itertools.chain(rw.reads, rw.writes):
for name in sorted(itertools.chain(rw.reads, rw.writes)):
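The point of the suggested change is reproducible codegen: Python set iteration order is not stable across runs, so sorting the chained sets pins the order. A quick illustration:

```python
import itertools

reads = {"beta", "alpha"}
writes = {"gamma", "alpha"}

# Iterating the raw chain depends on set order; sorting makes it stable.
# A name in both sets (here "alpha") still appears once per set.
names = sorted(itertools.chain(reads, writes))
# → ['alpha', 'alpha', 'beta', 'gamma']
```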

oulgen added a commit that referenced this pull request Jun 8, 2025
ghstack-source-id: 1f33fe9
Pull Request resolved: #52
@oulgen merged commit 528e5ee into gh/oulgen/5/base Jun 8, 2025
6 checks passed
oulgen added a commit that referenced this pull request Jun 8, 2025
ghstack-source-id: 1f33fe9
Pull Request resolved: #52
@oulgen deleted the gh/oulgen/5/head branch June 8, 2025 18:13
@yf225 (Contributor) commented Jun 9, 2025

@oulgen I suspect it might be causing an error on pytest test/test_examples.py::TestExamples::test_moe_matmul_ogs (which is not run on CI right now):

======================================================================
ERROR: test_moe_matmul_ogs (__main__.TestExamples.test_moe_matmul_ogs)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 680, in visit_Name
    assert type_info.origin.is_host()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/willfeng/local/helion/test/test_examples.py", line 1651, in test_moe_matmul_ogs
    run_example(
  File "/home/willfeng/local/helion/test/test_examples.py", line 34, in run_example
    code, result = code_and_output(
                   ^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_testing.py", line 45, in code_and_output
    code = fn.bind(args).to_triton_code(config)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 128, in bind
    bound_kernel = BoundKernel(self, args)
                   ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/runtime/kernel.py", line 288, in __init__
    self.host_function: HostFunction = HostFunction(
                                       ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/host_function.py", line 108, in __init__
    self.device_ir = lower_to_device_ir(self)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 869, in lower_to_device_ir
    visitor.visit(stmt)
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 850, in visit_For
    _make_fx(lambda: WalkDeviceAST(self.device_ir).visit(node))
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 124, in _make_fx
    return proxy_tensor.make_fx(fn, decomposition_table=select_decomp_table())(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2290, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2228, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_dynamo/eval_frame.py", line 893, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 850, in <lambda>
    _make_fx(lambda: WalkDeviceAST(self.device_ir).visit(node))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 509, in visit_For
    self._body(node.body)
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 418, in _body
    self.visit(stmt)
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 605, in visit_If
    self._create_if_subgraph(test_proxy, node.body)
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 639, in _create_if_subgraph
    body_graph = proxy_tensor.make_fx(
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2290, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2228, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_dynamo/eval_frame.py", line 893, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 627, in run_body
    subgraph_walker._body(body)
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 418, in _body
    self.visit(stmt)
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 556, in visit_For
    graph = proxy_tensor.make_fx(
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2290, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2228, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_dynamo/eval_frame.py", line 893, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 543, in run_subgraph
    subgraph_walker._body(node.body)
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 418, in _body
    self.visit(stmt)
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 556, in visit_For
    graph = proxy_tensor.make_fx(
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2290, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2228, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/_dynamo/eval_frame.py", line 893, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/willfeng/local/pytorch-nightly/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 543, in run_subgraph
    subgraph_walker._body(node.body)
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 418, in _body
    self.visit(stmt)
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 720, in visit_Assign
    self._assign(target, self.visit(node.value))
                         ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 773, in visit_Subscript
    return hl.load(self.visit(value), self._subscript_slice_proxy(node.slice))
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 688, in _subscript_slice_proxy
    result = self.visit(slice_node)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 221, in visit
    return visitor(node)
           ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/device_ir.py", line 694, in visit_Tuple
    return tuple([self.visit(x) for x in node.elts])
                  ^^^^^^^^^^^^^
  File "/home/willfeng/local/helion/helion/_compiler/ast_extension.py", line 225, in visit
    raise exc.InternalError(e) from e
helion.exc.InternalError: AssertionError: 
While processing:
  File "/home/willfeng/local/helion/examples/moe_matmul_ogs.py", line 74, in moe_matmul_ogs
    A_frag = A[expert_orig_token_indices, tile_k]  # [BLOCK_T, BLOCK_K]

@oulgen (Contributor Author) commented Jun 9, 2025

#148 should fix it

Labels: CLA Signed (managed by the Meta Open Source bot)
5 participants