Skip to content

Fixed rule decorator factory typing to help call-by-name call sites #21987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/notes/2.26.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ The default version of the [Ruff](https://docs.astral.sh/ruff/) tool has been up

The Pants repo now uses Ruff format in lieu of Black. This was not a drop-in replacement, with over 160 files modified (and about 5 MyPy errors introduced by Ruff's formatting).

`@rule` decorators have been re-typed, which should allow better call-site return-type visibility (fewer `Unknown`s and `Any`s). Decorator factories of the form `@rule(desc=..., level=..., ...)` have also been strongly typed. This may cause type-checking errors for plugin authors if a plugin passes incorrectly typed arguments. However, such usage would likely have manifested as a runtime crash anyway.

#### Shell

The `experimental_test_shell_command` target type may now be used with the `test` goal's `--debug` flag to execute the test interactively.
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ warn_unreachable = true
pretty = true
show_column_numbers = true
show_error_context = true
show_error_codes = true
show_traceback = true

[[tool.mypy.overrides]]
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/backend/project_info/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def get_paths_between_root_and_destination(pair: RootDestinationPair) -> S
return SpecsPaths(paths=spec_paths)


@rule("Get paths between root and multiple destinations.")
@rule(desc="Get paths between root and multiple destinations.")
async def get_paths_between_root_and_destinations(
pair: RootDestinationsPair,
) -> SpecsPathsCollection:
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/backend/python/lint/flake8/subsystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __bool__(self) -> bool:
return self.sources_digest != EMPTY_DIGEST


@rule("Prepare [flake8].source_plugins", level=LogLevel.DEBUG)
@rule(desc="Prepare [flake8].source_plugins", level=LogLevel.DEBUG)
async def flake8_first_party_plugins(flake8: Flake8) -> Flake8FirstPartyPlugins:
if not flake8.source_plugins:
return Flake8FirstPartyPlugins(FrozenOrderedSet(), FrozenOrderedSet(), EMPTY_DIGEST)
Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/backend/python/lint/pylint/subsystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __bool__(self) -> bool:
return self.sources_digest != EMPTY_DIGEST


@rule("Prepare [pylint].source_plugins", level=LogLevel.DEBUG)
@rule(desc="Prepare [pylint].source_plugins", level=LogLevel.DEBUG)
async def pylint_first_party_plugins(pylint: Pylint) -> PylintFirstPartyPlugins:
if not pylint.source_plugins:
return PylintFirstPartyPlugins(FrozenOrderedSet(), FrozenOrderedSet(), EMPTY_DIGEST)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class MyPyFirstPartyPlugins:
source_roots: tuple[str, ...]


@rule("Prepare [mypy].source_plugins", level=LogLevel.DEBUG)
@rule(desc="Prepare [mypy].source_plugins", level=LogLevel.DEBUG)
async def mypy_first_party_plugins(
mypy: MyPy,
) -> MyPyFirstPartyPlugins:
Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/engine/internals/build_files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
from collections.abc import Mapping
from textwrap import dedent
from typing import Any, cast
from typing import Any

import pytest

Expand Down Expand Up @@ -238,7 +238,7 @@ def run_prelude_parsing_rule(prelude_content: str) -> BuildFilePreludeSymbols:
),
],
)
return cast(BuildFilePreludeSymbols, symbols)
return symbols


def test_prelude_parsing_good() -> None:
Expand Down
84 changes: 73 additions & 11 deletions src/python/pants/engine/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,17 @@
from dataclasses import dataclass
from enum import Enum
from types import FrameType, ModuleType
from typing import Any, Protocol, TypeVar, Union, cast, get_type_hints, overload
from typing import (
Any,
NotRequired,
Protocol,
TypedDict,
TypeVar,
Unpack,
cast,
get_type_hints,
overload,
)

from typing_extensions import ParamSpec

Expand Down Expand Up @@ -49,7 +59,7 @@ class RuleType(Enum):
R = TypeVar("R")
SyncRuleT = Callable[P, R]
AsyncRuleT = Callable[P, Coroutine[Any, Any, R]]
RuleDecorator = Callable[[Union[SyncRuleT, AsyncRuleT]], AsyncRuleT]
RuleDecorator = Callable[[SyncRuleT | AsyncRuleT], AsyncRuleT]


def _rule_call_trampoline(
Expand Down Expand Up @@ -188,7 +198,39 @@ def _ensure_type_annotation(
IMPLICIT_PRIVATE_RULE_DECORATOR_ARGUMENTS = {"rule_type", "cacheable"}


def rule_decorator(func: SyncRuleT | AsyncRuleT, **kwargs) -> AsyncRuleT:
class RuleDecoratorKwargs(TypedDict):
"""Public-facing @rule kwargs used in the codebase."""

canonical_name: NotRequired[str]

canonical_name_suffix: NotRequired[str]

desc: NotRequired[str]
"""The rule's description as it appears in stacktraces/debugging. For goal rules, defaults to the goal name."""

level: NotRequired[LogLevel]
"""The logging level applied to this rule. Defaults to TRACE."""

_masked_types: NotRequired[Iterable[type[Any]]]
"""Unstable. Internal Pants usage only."""

_param_type_overrides: NotRequired[dict[str, type[Any]]]
"""Unstable. Internal Pants usage only."""


class _RuleDecoratorKwargs(RuleDecoratorKwargs):
"""Internal/Implicit @rule kwargs (not for use outside rules.py)"""

rule_type: RuleType
"""The decorator used to declare the rule (see rules.py:_make_rule(...))"""

cacheable: bool
"""Whether the results of this rule should be cached. Typically true for rules, false for goal_rules (see rules.py:_make_rule(...))"""


def rule_decorator(
func: SyncRuleT | AsyncRuleT, **kwargs: Unpack[_RuleDecoratorKwargs]
) -> AsyncRuleT:
if not inspect.isfunction(func):
raise ValueError("The @rule decorator expects to be placed on a function.")

Expand All @@ -205,8 +247,8 @@ def rule_decorator(func: SyncRuleT | AsyncRuleT, **kwargs) -> AsyncRuleT:
f"`@rule`s and `@goal_rule`s only accept the following keyword arguments: {PUBLIC_RULE_DECORATOR_ARGUMENTS}"
)

rule_type: RuleType = kwargs["rule_type"]
cacheable: bool = kwargs["cacheable"]
rule_type = kwargs["rule_type"]
cacheable = kwargs["cacheable"]
masked_types: tuple[type, ...] = tuple(kwargs.get("_masked_types", ()))
param_type_overrides: dict[str, type] = kwargs.get("_param_type_overrides", {})

Expand Down Expand Up @@ -257,7 +299,7 @@ def rule_decorator(func: SyncRuleT | AsyncRuleT, **kwargs) -> AsyncRuleT:
effective_desc = f"`{return_type.name}` goal"

effective_level = kwargs.get("level", LogLevel.TRACE)
if not isinstance(effective_level, LogLevel):
if not isinstance(effective_level, LogLevel): # type: ignore[unused-ignore]
raise ValueError(
"Expected to receive a value of type LogLevel for the level "
f"argument, but got: {effective_level}"
Expand Down Expand Up @@ -343,18 +385,38 @@ def wrapper(*args):
return wrapper


F = TypeVar("F", bound=Callable[..., Any | Coroutine[Any, Any, Any]])


@overload
def rule(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: ...
def rule(**kwargs: Unpack[RuleDecoratorKwargs]) -> Callable[[F], F]:
"""Handles decorator factories of the form `@rule(foo=..., bar=...)`
https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories.

Note: This needs to be the first rule, otherwise MyPy goes nuts
"""
...


@overload
def rule(func: Callable[P, R]) -> Callable[P, Coroutine[Any, Any, R]]: ...
def rule(_func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def rule(_func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
def rule(_func: AsyncRuleT) -> AsyncRuleT:

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this won't work as-is, since those type aliases need to be given their generic parameters. It ends up being:

def rule(_func: AsyncRuleT[P, R]) -> AsyncRuleT[P, R]:

And earlier on, there is another type alias of:

RuleDecorator = Callable[[SyncRuleT | AsyncRuleT], AsyncRuleT]

# needs to be

RuleDecorator = Callable[[SyncRuleT[P, R] | AsyncRuleT[P, R]], AsyncRuleT[P, R]]

But at usage, it should be
RuleDecorator[P, R]

So, I can do all that - but I feel like we're losing the plot with the Russian dolls of generic aliases - and I'd also bet that's how some of these type errors propagated in the first place. None of the uses of SyncRuleT and AsyncRuleT are correctly typed - which is why pyright loses its mind in that file.

Take the goal_rule overloads further down, which are fully qualified with Callables/Coroutines - they don't have any type errors, except the one usage of the type alias.

Finally, at call-usage, when I want to get information about what the typings are - my options are:

# Typealiases - which hide what happens:
def rule(_func: AsyncRuleT[P@rule, R@rule]) -> AsyncRuleT[P@rule, R@rule]: ... 

# More verbose, but more precise
def rule(_func: (**P@rule) -> Coroutine[Any, Any, R@rule]) -> ((**P@rule) -> Coroutine[Any, Any, R@rule]): ...

Nine times out of ten, I'd prefer the more concise form - but the decorator in this case is genuinely more complex, so it's nice to actually see what's going on, versus following 2-3 indirections to get to the meaning.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Please disregard my suggestion then.

Btw: I've never seen @ in type hints before. Is that a thing? Or just for example?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not in the type hints; it's in the output generated by pyright. Essentially, if you use ParamSpec, you get access to things like P.args and P.kwargs to use in the decorator wrapper typing, and the @location suffix indicates where the type variable came from.

I read the PEP, but don't recall the terminology they use to describe that.

"""Handles bare @rule decorators on async functions.

Usage of Coroutine[...] (vs Awaitable[...]) is intentional, as `MultiGet`/`concurrently` use
coroutines directly.
"""
...


@overload
def rule(
*args, func: None = None, **kwargs: Any
) -> Callable[[SyncRuleT | AsyncRuleT], AsyncRuleT]: ...
def rule(_func: Callable[P, R]) -> Callable[P, Coroutine[Any, Any, R]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def rule(_func: Callable[P, R]) -> Callable[P, Coroutine[Any, Any, R]]:
def rule(_func: SyncRuleT) -> AsyncRuleT:

"""Handles bare @rule decorators on non-async functions It's debatable whether we should even
have non-async @rule functions, but keeping this to not break the world for plugin authors.

Usage of Coroutine[...] (vs Awaitable[...]) is intentional, as `MultiGet`/`concurrently` use
coroutines directly.
"""
...


def rule(*args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically you can spec this and it won't affect caller typing, but it will itself be internally type-checked.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it, but my plan was to do that as part of the rules internals refactor. I didn't want to obscure this PR any more than necessary (because the internal refactor will touch around 150 lines).

Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/engine/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def another_named_rule(a: int, b: str) -> bool:
def test_bogus_rules(self) -> None:
with pytest.raises(UnrecognizedRuleArgument):

@rule(bogus_kwarg="TOTALLY BOGUS!!!!!!")
@rule(bogus_kwarg="TOTALLY BOGUS!!!!!!") # type: ignore
def a_named_rule(a: int, b: str) -> bool:
return False

Expand Down Expand Up @@ -1081,6 +1081,6 @@ async def obey_human_orders() -> A:

with pytest.raises(MissingParameterTypeAnnotation, match="must be a type"):

@rule(_param_type_overrides={"param1": "A string"})
@rule(_param_type_overrides={"param1": "A string"}) # type: ignore
async def protect_existence(param1) -> A:
return A()
Loading