Add hl.register_block_size and explicit tile sizes #30

Merged: 1 commit merged on May 12, 2025
22 changes: 22 additions & 0 deletions examples/softmax.py
@@ -29,6 +29,28 @@ def softmax_decomposed(x: torch.Tensor) -> torch.Tensor:
    return out


# This optimization does softmax in fewer passes, but is less numerically stable
@helion.kernel(config={"block_sizes": [1, 128]})
def softmax_two_pass(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty_like(x)
    block_size_m = hl.register_block_size(m)
    block_size_n = hl.register_block_size(n)
    for tile_m in hl.tile(m, block_size=block_size_m):
        mi = hl.full([tile_m, 1], float("-inf"), dtype=torch.float32)
        di = hl.zeros([tile_m, block_size_n], dtype=torch.float32)
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            local_amax = torch.amax(values, dim=1, keepdim=True)
            mi_next = torch.maximum(mi, local_amax)
            di = di * torch.exp(mi - mi_next) + torch.exp(values - mi_next)
            mi = mi_next
        for tile_n in hl.tile(n, block_size=block_size_n):
            values = x[tile_m, tile_n]
            out[tile_m, tile_n] = torch.exp(values - mi) / di.sum(dim=1, keepdim=True)
    return out


def check(m: int, n: int) -> None:
    from triton.testing import do_bench

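For context on what the new kernel computes: it is the standard online-softmax recurrence, where a running row maximum `mi` and a running denominator `di` are updated block by block, and a second pass normalizes with the final values. Below is a minimal eager-PyTorch sketch of the same recurrence, useful for sanity-checking against `torch.softmax`; the function name and `block_n` parameter are illustrative (not part of this PR), and it collapses the kernel's per-lane `di` accumulator into a per-row scalar since eager mode has no lanes to keep busy.

```python
import torch


def softmax_two_pass_reference(x: torch.Tensor, block_n: int = 128) -> torch.Tensor:
    """Eager sketch of the online-softmax recurrence used by softmax_two_pass."""
    m, n = x.size()
    out = torch.empty_like(x)
    mi = torch.full((m, 1), float("-inf"))
    di = torch.zeros((m, 1))
    # Pass 1: update the running row max and rescale the running denominator.
    for start in range(0, n, block_n):
        values = x[:, start : start + block_n]
        mi_next = torch.maximum(mi, values.amax(dim=1, keepdim=True))
        di = di * torch.exp(mi - mi_next) + torch.exp(values - mi_next).sum(
            dim=1, keepdim=True
        )
        mi = mi_next
    # Pass 2: normalize each block with the final max and denominator.
    for start in range(0, n, block_n):
        values = x[:, start : start + block_n]
        out[:, start : start + block_n] = torch.exp(values - mi) / di
    return out


if __name__ == "__main__":
    x = torch.randn(32, 1000)
    torch.testing.assert_close(softmax_two_pass_reference(x), torch.softmax(x, dim=1))
```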
107 changes: 103 additions & 4 deletions helion/_compiler/compile_environment.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import collections
import dataclasses
import threading
import types
import typing
@@ -26,6 +27,7 @@

from torch._guards import Source

from .. import Config
from .. import exc
from ..runtime.settings import Settings

@@ -78,7 +80,11 @@ def finalize_config_spec(self) -> None:
        )

    def allocate_block_size(
        self, size: int | torch.SymInt, *, reduction: bool = False
        self,
        size: int | torch.SymInt,
        *,
        reduction: bool = False,
        source: BlockSizeSource,
    ) -> int:
        idx = len(self.block_sizes)
        self.block_sizes.append(
@@ -89,6 +95,7 @@ def allocate_block_size(
                f"block_size_{idx}" if not reduction else f"rdim_{idx}"
            ),
            reduction=reduction,
            block_size_source=source,
            )
        )

@@ -104,7 +111,13 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo
        for rdim in self.block_sizes:
            if rdim.reduction and rdim.size == size:
                return rdim
        rdim_idx = self.allocate_block_size(size, reduction=True)
        rdim_idx = self.allocate_block_size(
            size,
            reduction=True,
            source=ReductionLoopBlockSizeSource(
                sum([int(bs.reduction) for bs in self.block_sizes])
            ),
        )
        return self.block_sizes[rdim_idx]

    def create_block_var(self, debug_name: str) -> torch.SymInt:
@@ -196,8 +209,8 @@ def known_equal(self, a: int | torch.SymInt, b: int | torch.SymInt) -> bool:
            return bool(res)
        return a == b

    def known_multiple(self, a: sympy.Expr, b: int) -> bool:
        if isinstance(a, (int, sympy.Integer)):
    def known_multiple(self, a: sympy.Expr, b: int | torch.SymInt) -> bool:
        if isinstance(a, (int, sympy.Integer)) and isinstance(b, int):
            return (int(a) % b) == 0
        return False

@@ -257,6 +270,7 @@ class BlockSizeInfo(typing.NamedTuple):
    size: torch.SymInt | int
    var: torch.SymInt
    reduction: bool
    block_size_source: BlockSizeSource

    @property
    def numel(self) -> sympy.Expr:
@@ -265,6 +279,91 @@ def numel(self) -> sympy.Expr:
    def symbol(self) -> sympy.Symbol:
        return self.var._sympy_()

    def from_config(self, config: Config) -> int | torch.SymInt | None:
        return self.block_size_source.from_config(config)

    def from_config_assert(self, config: Config) -> int | torch.SymInt:
        val = self.from_config(config)
        assert val is not None
        return val

    def is_flattened(self, config: Config) -> bool:
        return self.block_size_source.is_flattened(config)

    def get_order(self, config: Config, count: int) -> list[int]:
        return self.block_size_source.get_order(config, count)

    def l2_grouping(self, config: Config) -> int:
        return self.block_size_source.l2_grouping(config)


class BlockSizeSource:
    def from_config(self, config: Config) -> int | torch.SymInt | None:
        raise NotImplementedError

    def is_flattened(self, config: Config) -> bool:
        return False

    def get_order(self, config: Config, count: int) -> list[int]:
        return [*range(count)]

    def l2_grouping(self, config: Config) -> int:
        return 1


@dataclasses.dataclass
class FixedBlockSizeSource(BlockSizeSource):
    value: int | torch.SymInt

    def from_config(self, config: Config) -> int | torch.SymInt:
        return self.value


@dataclasses.dataclass
class LoopSpecBlockSizeSource(BlockSizeSource):
    loop_spec: int
    dim: int

    def from_config(self, config: Config) -> int:
        value = config.block_sizes[self.loop_spec]
        if isinstance(value, int):
            assert self.dim == 0
            return value
        return value[self.dim]

    def is_flattened(self, config: Config) -> bool:
        return isinstance(config.block_sizes[self.loop_spec], int)

    def get_order(self, config: Config, count: int) -> list[int]:
        env = CompileEnvironment.current()
        spec = env.config_spec.block_size_specs[self.loop_spec]
        if not spec.allow_reorder:
            return super().get_order(config, count)
        assert len(spec) == count
        order_offset = sum(
            [
                int(s.allow_reorder)
                for s in env.config_spec.block_size_specs[: self.loop_spec]
            ]
        )
        order = config.loop_orders[order_offset]
        assert len(order) == count
        return order

    def l2_grouping(self, config: Config) -> int:
        spec = CompileEnvironment.current().config_spec.block_size_specs[self.loop_spec]
        if spec.allow_l2_grouping:
            return config.l2_grouping
        return 1


@dataclasses.dataclass
class ReductionLoopBlockSizeSource(BlockSizeSource):
    reduction_loop: int

    def from_config(self, config: Config) -> int | None:
        return config.reduction_loops[self.reduction_loop]


def warning(warning: exc.BaseWarning | type[exc.BaseWarning]) -> None:
    CompileEnvironment.current().errors.add(warning)
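The BlockSizeSource hierarchy above is what maps each registered block size back to a concrete value for a given Config: FixedBlockSizeSource returns a constant, LoopSpecBlockSizeSource indexes into config.block_sizes (where an entry may be a flattened int or a per-dimension list), and ReductionLoopBlockSizeSource reads config.reduction_loops. A toy stand-in for the lookup rule, runnable on its own; `SimpleConfig` and `block_size_for` below are illustrative, not helion APIs:

```python
from dataclasses import dataclass, field


@dataclass
class SimpleConfig:
    # Stand-ins for the two Config fields the sources read from.
    block_sizes: list[int | list[int]] = field(default_factory=list)
    reduction_loops: list[int | None] = field(default_factory=list)


def block_size_for(config: SimpleConfig, loop_spec: int, dim: int) -> int:
    """Same lookup rule as LoopSpecBlockSizeSource.from_config."""
    value = config.block_sizes[loop_spec]
    if isinstance(value, int):  # flattened: one block size shared by the whole loop
        assert dim == 0
        return value
    return value[dim]  # per-dimension block sizes


config = SimpleConfig(block_sizes=[[32, 64], 128], reduction_loops=[None])
assert block_size_for(config, loop_spec=0, dim=1) == 64   # list entry: per-dim lookup
assert block_size_for(config, loop_spec=1, dim=0) == 128  # int entry: flattened
```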
9 changes: 8 additions & 1 deletion helion/_compiler/device_function.py
@@ -4,6 +4,7 @@
from collections import defaultdict
import dataclasses
import itertools
import math
import threading
from typing import TYPE_CHECKING
from typing import Protocol
@@ -149,13 +150,17 @@ def __init__(self, name: str, config: Config) -> None:
        self.namespace._used_names.update(reserved_names())
        self._variable_renames: dict[str, list[str]] = {}
        self.dce_vars: list[str] = []
        self.block_size_var_cache: dict[tuple[int, ...], str] = {}

        from .indexing_strategy import IndexingStrategy
        from .tile_dispatch import TileStrategyDispatch

        self.tile_strategy: TileStrategyDispatch = TileStrategyDispatch(self, config)
        self.indexing_strategy: IndexingStrategy = IndexingStrategy.select(config)

    def block_size_var(self, block_size_idx: int) -> str | None:
        return self.block_size_var_cache.get((block_size_idx,))

    def merge_variable_names(self, a: str, b: str) -> None:
        name_group = [
            *self._variable_renames.get(a, [a]),
@@ -184,7 +189,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str:
                )
                replacements[sym] = sympy.Symbol(arg.name, integer=True)
            elif isinstance(origin.origin, BlockSizeOrigin):
                result = self.tile_strategy.block_size_var(origin.origin.block_size_idx)
                result = self.block_size_var(origin.origin.block_size_idx)
                assert result is not None
                replacements[sym] = sympy.Symbol(result, integer=True)
            else:
@@ -210,6 +215,8 @@ def literal_expr(self, expr: object) -> str:
            return self.sympy_expr(expr._sympy_())
        if isinstance(expr, sympy.Expr):
            return self.sympy_expr(expr)
        if isinstance(expr, float) and not math.isfinite(expr):
            return f"float('{expr}')"
        return repr(expr)

    def unique_name(self, prefix: str, dce: bool = False) -> str:
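The new non-finite branch in `literal_expr` matters because `repr()` of a non-finite float is a bare name like `-inf` or `nan`, which does not round-trip as source text in generated code, while `float('-inf')` does. A small standalone check of that behavior; the `as_literal` helper is illustrative and only mirrors the special case added above:

```python
import math


def as_literal(expr: object) -> str:
    # Same special case as DeviceFunction.literal_expr above, minus the sympy paths.
    if isinstance(expr, float) and not math.isfinite(expr):
        return f"float('{expr}')"
    return repr(expr)


assert repr(float("-inf")) == "-inf"            # bare -inf: NameError if emitted as code
assert as_literal(float("-inf")) == "float('-inf')"
assert eval(as_literal(float("-inf"))) == float("-inf")
assert as_literal(float("nan")) == "float('nan')"
assert as_literal(3.5) == "3.5"
```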
34 changes: 32 additions & 2 deletions helion/_compiler/device_ir.py
@@ -213,6 +213,7 @@ def __init__(self) -> None:
        self.graphs: list[GraphInfo] = []
        self.root_id: int | None = None
        self.rolled_reductions: list[RolledReductionInfo] = []
        self.grid_block_indices: list[list[int]] = []

    def get_root(self, config: Config) -> torch.fx.GraphModule:
        """ " If we are using a rolled reduction, return the rolled reduction graph otherwise
@@ -399,6 +400,17 @@ def disable_tracing() -> Iterator[proxy_tensor.PythonKeyTracer]:
        with proxy_tensor.disable_proxy_modes_tracing():
            yield tracer

    @staticmethod
    def should_become_arg(value: object) -> bool:
        if isinstance(value, (TileIndexProxy, torch.SymInt)):
            return False
        if isinstance(value, torch.Tensor):
            if (
                origin := HostFunction.current().tensor_to_origin.get(value)
            ) is not None:
                return origin.is_device()
        return True

    def visit_For(self, node: ast.For) -> None:
        assert isinstance(node, ExtendedAST)
        assert not node.orelse
@@ -412,7 +424,11 @@ def visit_For(self, node: ast.For) -> None:
        elif node._loop_type == LoopType.DEVICE:
            rw: ReadWrites = ReadWrites.from_ast(node)
            inputs: LiftTensorArgs = LiftTensorArgs(
                {k: self.scope[k] for k in rw if k in self.scope}
                {
                    k: self.scope[k]
                    for k in rw
                    if k in self.scope and self.should_become_arg(self.scope[k])
                }
            )
            outputs: LiftTensorArgs | None = None

@@ -464,6 +480,8 @@ def run_subgraph(*args: object) -> list[object]:
                tracer=tracer,
            )
            for name, value in outputs.unflatten().items():
                if isinstance(value, TileIndexProxy):
                    continue
                if name in self.scope:
                    try:
                        self.scope[name] = _tracing_ops._phi(self.scope[name], value)
@@ -490,7 +508,11 @@ def visit_If(self, node: ast.If) -> object:
    def _create_if_subgraph(self, test_proxy: object, body: list[ast.stmt]) -> None:
        rw: ReadWrites = ReadWrites.from_list(body)
        inputs: LiftTensorArgs = LiftTensorArgs(
            {k: self.scope[k] for k in rw if k in self.scope}
            {
                k: self.scope[k]
                for k in rw
                if k in self.scope and self.should_become_arg(self.scope[k])
            }
        )
        outputs: LiftTensorArgs | None = None

@@ -703,6 +725,14 @@ def visit_For(self, node: ast.For) -> None:
            self.device_ir.add_root_graph(
                _make_fx(lambda: WalkDeviceAST(self.device_ir).visit(node))
            )
            iter_type = node.iter._type_info
            assert isinstance(iter_type, IterType)
            inner = iter_type.inner
            if isinstance(inner, SequenceType):
                block_indices = [x.block_size_idx for x in inner.unpack()]
            else:
                block_indices = [inner.block_size_idx]
            self.device_ir.grid_block_indices.append(block_indices)
        else:
            self.generic_visit(node)

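To illustrate what this last hunk records: each grid (outermost) loop gets the list of block-size indices of its tiles, one per dimension when the loop unpacks a sequence. A plain-Python mirror of that branch with stand-in objects; the `Tile` dataclass and helper below are illustrative, not helion types:

```python
from dataclasses import dataclass


@dataclass
class Tile:
    block_size_idx: int


def grid_block_indices(inner: Tile | list[Tile]) -> list[int]:
    # Mirrors the SequenceType / single-tile branch in visit_For above.
    if isinstance(inner, list):
        return [tile.block_size_idx for tile in inner]
    return [inner.block_size_idx]


assert grid_block_indices(Tile(0)) == [0]                 # e.g. softmax_two_pass's outer hl.tile(m)
assert grid_block_indices([Tile(0), Tile(1)]) == [0, 1]   # an unpacked 2-D grid loop
```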