Skip to content

Commit b2306ff

Browse files
oulgenpytorchmergebot
authored andcommitted
[RFC][Codemod] Rename device/host _fn to _function (#62)
There are so places where we use device_fn, host_fn, device_function and host_function. Lets be consistent everywhere. I'm also ok if we want to swap to _fn versions, as long as we are consistent. Pull Request resolved: #62 Approved by: https://github.com/jansel
1 parent 6497bea commit b2306ff

File tree

7 files changed

+35
-31
lines changed

7 files changed

+35
-31
lines changed

helion/_compiler/device_function.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,13 +245,13 @@ def tensor_arg(
245245
def tensor_descriptor_arg(
246246
self, fake_value: torch.Tensor, block_size: list[int | torch.SymInt]
247247
) -> TensorArg:
248-
host_fn = HostFunction.current()
248+
host_function = HostFunction.current()
249249
block_size_expr = ", ".join(
250250
map(HostFunction.current().literal_expr, block_size)
251251
)
252252
key = (fake_value, block_size_expr)
253253
if key not in self._tensor_descriptor_args:
254-
origin = host_fn.tensor_to_origin[fake_value]
254+
origin = host_function.tensor_to_origin[fake_value]
255255
arg = TensorDescriptorArg(
256256
self.new_var(origin.suggest_var_name() + "_desc"),
257257
fake_value,

helion/_compiler/generate_ast.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
class GenerateAST(NodeVisitor):
3434
def __init__(self, func: HostFunction, config: Config) -> None:
3535
super().__init__()
36-
self.host_fn = func
36+
self.host_function = func
3737
self.host_statements: list[ast.AST] = []
3838
self.statements_stack: list[list[ast.AST]] = [self.host_statements]
3939
self.on_device = False
@@ -175,7 +175,7 @@ def visit_For(self, node: ast.For) -> ast.AST:
175175
)
176176
codegen_call_with_graph(
177177
self,
178-
self.host_fn.device_ir.get_root(self.device_function.config),
178+
self.host_function.device_ir.get_root(self.device_function.config),
179179
[],
180180
)
181181
self.device_function.dead_code_elimination()
@@ -188,9 +188,11 @@ def visit_Name(self, node: ast.Name) -> ast.AST:
188188
origin = node._type_info.origin
189189
if (
190190
isinstance(origin, ArgumentOrigin)
191-
and origin.name in self.host_fn.constexpr_args
191+
and origin.name in self.host_function.constexpr_args
192192
):
193-
return expr_from_string(repr(self.host_fn.constexpr_args[origin.name]))
193+
return expr_from_string(
194+
repr(self.host_function.constexpr_args[origin.name])
195+
)
194196
if origin.needs_rename():
195197
# `x` => `_original_globals.x`
196198
return expr_from_string(origin.host_str())
@@ -207,7 +209,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
207209
elif isinstance(type_info := node._type_info, TileIndexType):
208210
block_info = env.block_sizes[type_info.block_size_idx]
209211
return expr_from_string(
210-
self.host_fn.literal_expr(
212+
self.host_function.literal_expr(
211213
block_info.from_config(self.device_function.config)
212214
)
213215
)
@@ -216,7 +218,7 @@ def visit_Call(self, node: ast.Call) -> ast.AST:
216218
if all(isinstance(x, TileIndexType) for x in values):
217219
block_infos = [env.block_sizes[x.block_size_idx] for x in values]
218220
return expr_from_string(
219-
self.host_fn.literal_expr(
221+
self.host_function.literal_expr(
220222
[
221223
x.from_config(self.device_function.config)
222224
for x in block_infos
@@ -248,15 +250,15 @@ def has_mask(self) -> bool:
248250

249251

250252
def codegen_precompile_def(
251-
host_def: ast.FunctionDef, device_fn_name: str
253+
host_def: ast.FunctionDef, device_function_name: str
252254
) -> ast.FunctionDef:
253255
"""
254256
Generate a precompile function definition for the given host function.
255257
The precompile function is the same as the normal function, but the call to the
256258
kernel is replaced with a call to make_precompiler.
257259
258260
:param host_def: The host function definition to that is used to call the kernel.
259-
:param device_fn_name: The name of the device function to be called.
261+
:param device_function_name: The name of the device function to be called.
260262
:return: A transformed function definition with the kernel call replaced.
261263
"""
262264

@@ -285,7 +287,7 @@ def transform(node: ExtendedAST) -> ExtendedAST:
285287
ast.Return,
286288
value=value.copy(
287289
func=expr_from_string(
288-
f"make_precompiler({device_fn_name})"
290+
f"make_precompiler({device_function_name})"
289291
)
290292
),
291293
)

helion/_compiler/lift_closures.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ def lift_closures(func: FunctionType, origin: Origin) -> FunctionType:
3636
def wrapper(*args: object, **kwargs: object) -> object:
3737
nonlocal new_func, closure_contents
3838
if new_func is None:
39-
host_fn = HostFunction.current()
39+
host_function = HostFunction.current()
4040
closure = None
4141
if func.__closure__ is not None:
4242
closure_contents = [
43-
host_fn.register_fake(obj.cell_contents, ClosureOrigin(origin, i))
43+
host_function.register_fake(
44+
obj.cell_contents, ClosureOrigin(origin, i)
45+
)
4446
for i, obj in enumerate(func.__closure__)
4547
]
4648
closure = (*map(make_cell, closure_contents),)

helion/_compiler/reduction_strategy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def __init__(
235235
def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
236236
env = CompileEnvironment.current()
237237
block_index = self.block_index
238-
device_fn = state.device_function
238+
device_function = state.device_function
239239
numel = env.block_sizes[block_index].numel
240240
offset_var = self.offset_var(block_index)
241241
index_var = self.index_var(block_index)
@@ -253,14 +253,14 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
253253
if (mask_var := self._mask_var) is not None:
254254
body.append(
255255
statement_from_string(
256-
f"{mask_var} = {index_var} < {device_fn.sympy_expr(numel)}"
256+
f"{mask_var} = {index_var} < {device_function.sympy_expr(numel)}"
257257
)
258258
)
259259
for_node = create(
260260
ast.For,
261261
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
262262
iter=expr_from_string(
263-
f"range(0, ({device_fn.sympy_expr(numel)}), {block_size_var})"
263+
f"range(0, ({device_function.sympy_expr(numel)}), {block_size_var})"
264264
),
265265
body=body,
266266
orelse=[],

helion/_compiler/source_location.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,14 @@ def __init__(
6767
def from_ast(node: ast.AST) -> SourceLocation:
6868
from .host_function import HostFunction
6969

70-
host_fn = HostFunction.current()
71-
code = host_fn.fn.__code__
70+
host_function = HostFunction.current()
71+
code = host_function.fn.__code__
7272
offset = code.co_firstlineno - 1
7373
return SourceLocation(
7474
node.lineno + offset,
75-
node.col_offset + host_fn.column_offset,
75+
node.col_offset + host_function.column_offset,
7676
node.end_lineno + offset,
77-
node.end_col_offset + host_fn.column_offset,
77+
node.end_col_offset + host_function.column_offset,
7878
filename=code.co_filename,
7979
name=code.co_name,
8080
)

helion/_compiler/tile_strategy.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def _codegen_common(
211211
block_indices = self.block_indices
212212
env = CompileEnvironment.current()
213213
total_numel = sympy.S.One
214-
device_fn = state.device_function
214+
device_function = state.device_function
215215
offsets_var = self.new_var("offsets", dce=True)
216216
block_size_var = self.block_size_var(-1)
217217
statements = []
@@ -226,17 +226,17 @@ def _codegen_common(
226226
block_index_var = self.index_var(block_idx)
227227
expr = offsets_var
228228
if total_numel != sympy.S.One:
229-
expr = f"({expr}) // ({device_fn.sympy_expr(total_numel)})"
229+
expr = f"({expr}) // ({device_function.sympy_expr(total_numel)})"
230230
if i + 1 < len(block_indices):
231-
expr = f"({expr}) % ({device_fn.sympy_expr(numel)})"
231+
expr = f"({expr}) % ({device_function.sympy_expr(numel)})"
232232
statements.append(statement_from_string(f"{block_index_var} = {expr}"))
233233
total_numel = total_numel * numel
234234

235235
mask_var = self.mask_var(-1)
236236
if mask_var is not None:
237237
statements.append(
238238
statement_from_string(
239-
f"{mask_var} = {offsets_var} < ({device_fn.sympy_expr(total_numel)})"
239+
f"{mask_var} = {offsets_var} < ({device_function.sympy_expr(total_numel)})"
240240
)
241241
)
242242
return block_size_var, offsets_var, total_numel, statements
@@ -375,7 +375,7 @@ def mask_var(self, block_idx: int) -> str | None:
375375
def codegen_grid(self, state: CodegenState) -> None:
376376
block_indices = self.block_indices
377377
env = CompileEnvironment.current()
378-
device_fn = state.device_function
378+
device_function = state.device_function
379379
dtype = env.triton_index_type()
380380
block_sizes = self.block_size
381381
assert len(block_sizes) == len(block_indices)
@@ -386,7 +386,7 @@ def codegen_grid(self, state: CodegenState) -> None:
386386
numel = env.block_sizes[block_idx].numel
387387
offset_var = self.offset_var(block_idx)
388388
index_var = self.index_var(block_idx)
389-
pid_var = device_fn.new_var(f"pid_{i}", dce=True)
389+
pid_var = device_function.new_var(f"pid_{i}", dce=True)
390390
if block_size != 1:
391391
block_size_var = self.block_size_var(block_idx)
392392
assert block_size_var is not None
@@ -444,7 +444,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
444444
# TODO(jansel): refactor this to share code with codegen_grid
445445
block_indices = self.block_indices
446446
env = CompileEnvironment.current()
447-
device_fn = state.device_function
447+
device_function = state.device_function
448448
dtype = env.triton_index_type()
449449
block_sizes = self.block_size
450450
body = innermost_body = []
@@ -471,7 +471,7 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState:
471471
ast.For,
472472
target=create(ast.Name, id=offset_var, ctx=ast.Store()),
473473
iter=expr_from_string(
474-
f"range(0, ({device_fn.sympy_expr(numel)}), {block_size_var})"
474+
f"range(0, ({device_function.sympy_expr(numel)}), {block_size_var})"
475475
),
476476
body=body,
477477
orelse=[],

helion/runtime/kernel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def __init__(self, kernel: Kernel, args: tuple[object, ...]) -> None:
251251
else:
252252
self.fake_args.append(self.env.to_fake(arg, ArgumentOrigin(name)))
253253
with _maybe_skip_dtype_check_in_meta_registrations():
254-
self.host_fn: HostFunction = HostFunction(
254+
self.host_function: HostFunction = HostFunction(
255255
self.kernel.fn, self.fake_args, constexpr_args
256256
)
257257
if len(kernel.configs) == 1:
@@ -301,7 +301,7 @@ def to_triton_code(self, config: ConfigLike) -> str:
301301
# pyre-ignore[6]
302302
config = Config(**config)
303303
self.env.config_spec.normalize(config)
304-
root = generate_ast(self.host_fn, config)
304+
root = generate_ast(self.host_function, config)
305305
return get_needed_imports(root) + unparse(root)
306306

307307
def compile_config(self, config: ConfigLike) -> CompiledConfig:
@@ -334,7 +334,7 @@ def _debug_str(self) -> str:
334334
:rtype: str
335335
"""
336336
with self.env:
337-
return self.host_fn.debug_str()
337+
return self.host_function.debug_str()
338338

339339
def autotune(
340340
self,

0 commit comments

Comments
 (0)