Skip to content

Refactor autotuning logging #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 13 additions & 41 deletions helion/autotuner/base_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import dataclasses
import functools
from itertools import starmap
import logging
import math
from math import inf
from multiprocessing import connection
import re
import sys
import time
from typing import TYPE_CHECKING
from typing import NamedTuple
Expand All @@ -20,12 +20,11 @@

from .. import exc
from ..runtime.precompile_shim import already_compiled
from ..runtime.settings import LogLevel
from .config_generation import ConfigGeneration
from .config_generation import FlatConfig
from .logger import LambdaLogger

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Sequence

from ..runtime.config import Config
Expand Down Expand Up @@ -69,6 +68,7 @@ def __init__(self, kernel: BoundKernel, args: Sequence[object]) -> None:
self.config_spec: ConfigSpec = kernel.config_spec
self.args = args
self.counters: collections.Counter[str] = collections.Counter()
self.log = LambdaLogger(self.settings.autotune_log_level)

def benchmark(self, config: Config) -> float:
"""
Expand Down Expand Up @@ -96,7 +96,7 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
:return: The performance of the configuration in seconds.
"""
self.counters["benchmark"] += 1
self.log(lambda: f"Running benchmark for {config!r}", level=LogLevel.DEBUG)
self.log.debug(lambda: f"Running benchmark for {config!r}")
try:
# TODO(jansel): early exit with fewer trials if early runs are slow
t0 = time.perf_counter()
Expand All @@ -107,19 +107,16 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
return_mode="median",
)
t2 = time.perf_counter()
self.log(
f"result: {res:.4f}s (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
level=LogLevel.DEBUG,
self.log.debug(
lambda: f"result: {res:.4f}s (took {t1 - t0:.1f}s + {t2 - t1:.1f}s)",
)
return res
except OutOfResources:
self.log("Benchmarking failed: OutOfResources", level=LogLevel.DEBUG)
self.log.debug("Benchmarking failed: OutOfResources")
except Exception as e:
if not _expected_errors_regexp.search(str(e)):
raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
self.log(
f"Benchmarking failed: {type(e).__name__}: {e}", level=LogLevel.DEBUG
)
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
return inf

def start_precompile_and_check_for_hangs(
Expand Down Expand Up @@ -181,18 +178,6 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
results.append((config, inf))
return results

def log(self, *msg: str | Callable[[], str], level: int = LogLevel.INFO) -> None:
"""
Log a message at a specified log level.

:param msg: The message(s) to log. Can be strings or callables that return strings.
:type msg: str | Callable[[], str]
:param level: The log level for the message.
:type level: int
"""
if self.settings.autotune_log_level >= level:
sys.stderr.write(" ".join(map(_maybe_call, msg)) + "\n")

def autotune(self) -> Config:
"""
Perform autotuning to find the best configuration.
Expand All @@ -203,13 +188,14 @@ def autotune(self) -> Config:
:rtype: Config
"""
start = time.perf_counter()
self.log.reset()
best = self._autotune()
end = time.perf_counter()
self.log(
f"Autotuning complete in {end - start:.1f}s after searching {self.counters['benchmark']} configs.\n"
"One can hardcode the best config with and skip autotuning with:\n"
f" @helion.kernel(config={best!r})\n",
level=LogLevel.SUMMARY,
level=logging.INFO + 5,
)
return best

Expand All @@ -224,20 +210,6 @@ def _autotune(self) -> Config:
raise NotImplementedError


def _maybe_call(fn: Callable[[], str] | str) -> str:
"""
Call a callable or return the string directly.

:param fn: A callable that returns a string or a string.
:type fn: Callable[[], str] | str
:return: The resulting string.
:rtype: str
"""
if callable(fn):
return fn()
return fn


class PopulationMember(NamedTuple):
"""
Represents a member of the population in population-based search algorithms.
Expand Down Expand Up @@ -499,14 +471,14 @@ def _mark_complete(self) -> bool:
process.join(10)
msg = f"Timeout after {self.elapsed:.0f}s compiling {self.config}"
if process.is_alive():
self.search.log(
self.search.log.warning(
msg,
"(SIGKILL required)",
level=LogLevel.WARNING,
)
process.kill()
process.join()
else:
self.search.log(msg, level=LogLevel.WARNING)
self.search.log.warning(msg)

self.ok = False
return False
12 changes: 1 addition & 11 deletions helion/autotuner/differential_evolution.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

from math import inf
import random
import time
from typing import TYPE_CHECKING

from .base_search import FlatConfig
Expand Down Expand Up @@ -41,10 +39,6 @@ def __init__(
self.num_generations = num_generations
self.crossover_rate = crossover_rate
self.immediate_update = immediate_update
self.start_time: float = -inf

def timestamp(self) -> str:
return f"[{time.perf_counter() - self.start_time:.0f}s]"

def mutate(self, x_index: int) -> FlatConfig:
a, b, c, *_ = [
Expand All @@ -69,7 +63,6 @@ def initial_two_generations(self) -> None:
key=performance,
)
self.log(
self.timestamp,
"Initial population:",
lambda: population_statistics(oversized_population),
)
Expand All @@ -96,7 +89,6 @@ def evolve_population(self) -> int:
return replaced

def _autotune(self) -> Config:
self.start_time = time.perf_counter()
self.log(
lambda: (
f"Starting DifferentialEvolutionSearch with population={self.population_size}, "
Expand All @@ -106,7 +98,5 @@ def _autotune(self) -> Config:
self.initial_two_generations()
for i in range(2, self.num_generations):
replaced = self.evolve_population()
self.log(
self.timestamp, f"Generation {i}: replaced={replaced}", self.statistics
)
self.log(f"Generation {i}: replaced={replaced}", self.statistics)
return self.best.config
83 changes: 83 additions & 0 deletions helion/autotuner/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

import itertools
import logging
import sys
import time
from typing import Callable


class LambdaLogger:
"""
A self-contained logger that does not propagates to the root logger and
prints each record to stderr in the form:

[<elapsed>s] <message>

where *elapsed* is the whole-second wall-clock time since the logger
instance was created.

Takes lambas as arguments, which are called when the log is emitted.
"""

_count: itertools.count[int] = itertools.count()

def __init__(self, level: int) -> None:
self.level = level
self._logger: logging.Logger = logging.getLogger(
f"{__name__}.{next(self._count)}"
)
self._logger.setLevel(level)
self._logger.propagate = False
self.reset()

def reset(self) -> None:
self._logger.handlers.clear()
self._logger.addHandler(_make_handler())

def __call__(
self, *msg: str | Callable[[], str], level: int = logging.INFO
) -> None:
"""
Log a message at a specified log level.

:param msg: The message(s) to log. Can be strings or callables that return strings.
:type msg: str | Callable[[], str]
:param level: The log level for the message.
:type level: int
"""
if level >= self.level:
self._logger.log(level, " ".join(map(_maybe_call, msg)))

def warning(self, *msg: str | Callable[[], str]) -> None:
return self(*msg, level=logging.WARNING)

def debug(self, *msg: str | Callable[[], str]) -> None:
return self(*msg, level=logging.DEBUG)


def _make_handler() -> logging.Handler:
start = time.perf_counter()

class _ElapsedFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str: # type: ignore[override]
elapsed = int(time.perf_counter() - start)
return f"[{elapsed}s] {record.getMessage()}"

handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(_ElapsedFormatter())
return handler


def _maybe_call(fn: Callable[[], str] | str) -> str:
"""
Call a callable or return the string directly.

:param fn: A callable that returns a string or a string.
:type fn: Callable[[], str] | str
:return: The resulting string.
:rtype: str
"""
if callable(fn):
return fn()
return fn
22 changes: 2 additions & 20 deletions helion/runtime/settings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import dataclasses
import logging
import sys
import threading
from typing import TYPE_CHECKING
Expand All @@ -22,25 +23,6 @@ class _TLS(Protocol):
_tls: _TLS = cast("_TLS", threading.local())


class LogLevel:
"""
Enumeration for log levels used in the search algorithms.

Attributes:
OFF (0): No logging.
SUMMARY (10): Log summary information.
WARNING (20): Log warnings.
INFO (30): Log informational messages.
DEBUG (40): Log detailed debug messages.
"""

OFF = 0
SUMMARY = 10
WARNING = 20
INFO = 30
DEBUG = 40


def set_default_settings(settings: Settings) -> AbstractContextManager[None, None]:
"""
Set the default settings for the current thread and return a context manager
Expand Down Expand Up @@ -72,7 +54,7 @@ class _Settings:
dot_precision: Literal["tf32", "tf32x3", "ieee"] = "tf32"
static_shapes: bool = False
use_default_config: bool = False
autotune_log_level: int = LogLevel.INFO
autotune_log_level: int = logging.INFO
autotune_compile_timeout: int = 60
autotune_precompile: bool = sys.platform != "win32"

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ select = [
"TD004", "TRY002", "TRY203", "TRY401", "UP", "W", "YTT",
]
ignore = [
"C409", "C419", "COM812", "E501", "ERA001", "FURB189", "PERF203", "PT009",
"SIM102", "SIM108", "SIM115", "UP038", "UP035",
"C409", "C419", "COM812", "E501", "ERA001", "FURB189", "G004", "PERF203", "PT009",
"SIM102", "SIM108", "SIM115", "UP035", "UP038",
]
extend-safe-fixes = ["TC", "UP045", "RUF013"]
preview = true
Expand Down