Skip to content

Commit 49a05d0

Browse files
committed
Do not print output code durring autotuning
Printing 1500+ triton kernels during autotuning is not useful. stack-info: PR: #130, branch: jansel/stack/26
1 parent e2a0556 commit 49a05d0

File tree

3 files changed

+20
-11
lines changed

3 files changed

+20
-11
lines changed

helion/autotuner/base_search.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from math import inf
1111
from multiprocessing import connection
1212
import re
13+
import sys
1314
import time
1415
from typing import TYPE_CHECKING
1516
from typing import NamedTuple
@@ -82,7 +83,7 @@ def benchmark(self, config: Config) -> float:
8283
:return: The performance of the configuration in seconds.
8384
:rtype: float
8485
"""
85-
fn = self.kernel.compile_config(config)
86+
fn = self.kernel.compile_config(config, allow_print=False)
8687
if self.start_precompile_and_check_for_hangs(config, fn)():
8788
return self.benchmark_function(config, fn)
8889
return inf
@@ -160,7 +161,7 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
160161
:return: A list of tuples containing configurations and their performance.
161162
:rtype: list[tuple[Config, float]]
162163
"""
163-
fns = [self.kernel.compile_config(c) for c in configs]
164+
fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
164165
if self.settings.autotune_precompile:
165166
is_workings = PrecompileFuture.wait_for_all(
166167
[
@@ -200,6 +201,9 @@ def autotune(self) -> Config:
200201
f" @helion.kernel(config={best!r})\n",
201202
level=logging.INFO + 5,
202203
)
204+
if self.settings.print_output_code:
205+
triton_code = self.kernel.to_triton_code(best)
206+
print(triton_code, file=sys.stderr)
203207
return best
204208

205209
def _autotune(self) -> Config:

helion/runtime/kernel.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,12 +328,16 @@ def to_triton_code(self, config: ConfigLike) -> str:
328328
root = generate_ast(self.host_function, config)
329329
return get_needed_imports(root) + unparse(root)
330330

331-
def compile_config(self, config: ConfigLike) -> CompiledConfig:
331+
def compile_config(
332+
self, config: ConfigLike, *, allow_print: bool = True
333+
) -> CompiledConfig:
332334
"""
333335
Compile the kernel for a specific configuration.
334336
335337
:param config: The configuration to compile the kernel with.
336338
:type config: Config or dict[str, object]
339+
:param allow_print: Set to suppress printing the output code when autotuning.
340+
:type allow_print: bool
337341
:return: A callable object representing the compiled kernel.
338342
:rtype: Callable[..., object]
339343
"""
@@ -342,10 +346,11 @@ def compile_config(self, config: ConfigLike) -> CompiledConfig:
342346
if (rv := self._compile_cache.get(config)) is not None:
343347
return rv
344348
triton_code = self.to_triton_code(config)
345-
log.info("Output code: \n%s", triton_code)
346-
log.debug("Debug string: \n%s", LazyString(lambda: self._debug_str()))
347-
if self.settings.print_output_code:
348-
print(triton_code, file=sys.stderr)
349+
if allow_print:
350+
log.info("Output code: \n%s", triton_code)
351+
log.debug("Debug string: \n%s", LazyString(lambda: self._debug_str()))
352+
if self.settings.print_output_code:
353+
print(triton_code, file=sys.stderr)
349354
module = PyCodeCache.load(triton_code)
350355
rv = getattr(module, self.kernel.name)
351356
rv.make_precompiler = getattr(module, f"_{self.kernel.name}_make_precompiler")

helion/runtime/settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ class _Settings:
5454
index_dtype: torch.dtype = torch.int32
5555
dot_precision: Literal["tf32", "tf32x3", "ieee"] = "tf32"
5656
static_shapes: bool = False
57-
use_default_config: bool = False
57+
use_default_config: bool = os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") == "1"
5858
autotune_log_level: int = logging.INFO
59-
autotune_compile_timeout: int = 60
59+
autotune_compile_timeout: int = int(
60+
os.environ.get("HELION_AUTOTUNE_COMPILE_TIMEOUT", "60")
61+
)
6062
autotune_precompile: bool = sys.platform != "win32"
6163
print_output_code: bool = os.environ.get("HELION_PRINT_OUTPUT_CODE", "0") == "1"
6264

@@ -91,8 +93,6 @@ def __init__(self, **settings: object) -> None:
9193
settings = {**defaults.to_dict(), **settings}
9294
# pyre-ignore[6]
9395
super().__init__(**settings)
94-
if os.getenv("HELION_USE_DEFAULT_CONFIG") == "1":
95-
self.use_default_config: bool = True
9696

9797
def to_dict(self) -> dict[str, object]:
9898
"""

0 commit comments

Comments
 (0)