Skip to content

Commit 90b367d

Browse files
authored
🔧 Add typing of rule functions (#283)
Rule functions signature is specific to the state it acts on.
1 parent 64965cf commit 90b367d

File tree

8 files changed

+65
-48
lines changed

8 files changed

+65
-48
lines changed

docs/conf.py

-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
".*Literal.*",
5454
".*_Result",
5555
"EnvType",
56-
"RuleFunc",
5756
"Path",
5857
"Ellipsis",
5958
)

markdown_it/parser_block.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from __future__ import annotations
33

44
import logging
5-
from typing import TYPE_CHECKING, Any
5+
from typing import TYPE_CHECKING, Callable
66

77
from . import rules_block
88
from .ruler import Ruler
@@ -16,7 +16,13 @@
1616
LOGGER = logging.getLogger(__name__)
1717

1818

19-
_rules: list[tuple[str, Any, list[str]]] = [
19+
RuleFuncBlockType = Callable[[StateBlock, int, int, bool], bool]
20+
"""(state: StateBlock, startLine: int, endLine: int, silent: bool) -> matched: bool)
21+
22+
`silent` disables token generation, useful for lookahead.
23+
"""
24+
25+
_rules: list[tuple[str, RuleFuncBlockType, list[str]]] = [
2026
# First 2 params - rule name & source. Secondary array - list of rules,
2127
# which can be terminated by this one.
2228
("table", rules_block.table, ["paragraph", "reference"]),
@@ -45,7 +51,7 @@ class ParserBlock:
4551
"""
4652

4753
def __init__(self) -> None:
48-
self.ruler = Ruler()
54+
self.ruler = Ruler[RuleFuncBlockType]()
4955
for name, rule, alt in _rules:
5056
self.ruler.push(name, rule, {"alt": alt})
5157

markdown_it/parser_core.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
"""
77
from __future__ import annotations
88

9-
from .ruler import RuleFunc, Ruler
9+
from typing import Callable
10+
11+
from .ruler import Ruler
1012
from .rules_core import (
1113
block,
1214
inline,
@@ -18,7 +20,9 @@
1820
)
1921
from .rules_core.state_core import StateCore
2022

21-
_rules: list[tuple[str, RuleFunc]] = [
23+
RuleFuncCoreType = Callable[[StateCore], None]
24+
25+
_rules: list[tuple[str, RuleFuncCoreType]] = [
2226
("normalize", normalize),
2327
("block", block),
2428
("inline", inline),
@@ -31,7 +35,7 @@
3135

3236
class ParserCore:
3337
def __init__(self) -> None:
34-
self.ruler = Ruler()
38+
self.ruler = Ruler[RuleFuncCoreType]()
3539
for name, rule in _rules:
3640
self.ruler.push(name, rule)
3741

markdown_it/parser_inline.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,25 @@
22
"""
33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Callable
66

77
from . import rules_inline
8-
from .ruler import RuleFunc, Ruler
8+
from .ruler import Ruler
99
from .rules_inline.state_inline import StateInline
1010
from .token import Token
1111
from .utils import EnvType
1212

1313
if TYPE_CHECKING:
1414
from markdown_it import MarkdownIt
1515

16+
1617
# Parser rules
17-
_rules: list[tuple[str, RuleFunc]] = [
18+
RuleFuncInlineType = Callable[[StateInline, bool], bool]
19+
"""(state: StateInline, silent: bool) -> matched: bool)
20+
21+
`silent` disables token generation, useful for lookahead.
22+
"""
23+
_rules: list[tuple[str, RuleFuncInlineType]] = [
1824
("text", rules_inline.text),
1925
("linkify", rules_inline.linkify),
2026
("newline", rules_inline.newline),
@@ -34,7 +40,8 @@
3440
#
3541
# Don't use this for anything except pairs (plugins working with `balance_pairs`).
3642
#
37-
_rules2: list[tuple[str, RuleFunc]] = [
43+
RuleFuncInline2Type = Callable[[StateInline], None]
44+
_rules2: list[tuple[str, RuleFuncInline2Type]] = [
3845
("balance_pairs", rules_inline.link_pairs),
3946
("strikethrough", rules_inline.strikethrough.postProcess),
4047
("emphasis", rules_inline.emphasis.postProcess),
@@ -46,11 +53,11 @@
4653

4754
class ParserInline:
4855
def __init__(self) -> None:
49-
self.ruler = Ruler()
56+
self.ruler = Ruler[RuleFuncInlineType]()
5057
for name, rule in _rules:
5158
self.ruler.push(name, rule)
5259
# Second ruler used for post-processing (e.g. in emphasis-like rules)
53-
self.ruler2 = Ruler()
60+
self.ruler2 = Ruler[RuleFuncInline2Type]()
5461
for name, rule2 in _rules2:
5562
self.ruler2.push(name, rule2)
5663

markdown_it/ruler.py

+23-22
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ class Ruler
1717
"""
1818
from __future__ import annotations
1919

20-
from collections.abc import Callable, Iterable
20+
from collections.abc import Iterable
2121
from dataclasses import dataclass, field
22-
from typing import TYPE_CHECKING, TypedDict
22+
from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar
2323
import warnings
2424

2525
from markdown_it._compat import DATACLASS_KWARGS
@@ -57,33 +57,30 @@ def srcCharCode(self) -> tuple[int, ...]:
5757
return self._srcCharCode
5858

5959

60-
# The first positional arg is always a subtype of `StateBase`. Other
61-
# arguments may or may not exist, based on the rule's type (block,
62-
# core, inline). Return type is either `None` or `bool` based on the
63-
# rule's type.
64-
RuleFunc = Callable # type: ignore
65-
66-
6760
class RuleOptionsType(TypedDict, total=False):
6861
alt: list[str]
6962

7063

64+
RuleFuncTv = TypeVar("RuleFuncTv")
65+
"""A rule function, whose signature is dependent on the state type."""
66+
67+
7168
@dataclass(**DATACLASS_KWARGS)
72-
class Rule:
69+
class Rule(Generic[RuleFuncTv]):
7370
name: str
7471
enabled: bool
75-
fn: RuleFunc = field(repr=False)
72+
fn: RuleFuncTv = field(repr=False)
7673
alt: list[str]
7774

7875

79-
class Ruler:
76+
class Ruler(Generic[RuleFuncTv]):
8077
def __init__(self) -> None:
8178
# List of added rules.
82-
self.__rules__: list[Rule] = []
79+
self.__rules__: list[Rule[RuleFuncTv]] = []
8380
# Cached rule chains.
8481
# First level - chain name, '' for default.
8582
# Second level - diginal anchor for fast filtering by charcodes.
86-
self.__cache__: dict[str, list[RuleFunc]] | None = None
83+
self.__cache__: dict[str, list[RuleFuncTv]] | None = None
8784

8885
def __find__(self, name: str) -> int:
8986
"""Find rule index by name"""
@@ -112,7 +109,7 @@ def __compile__(self) -> None:
112109
self.__cache__[chain].append(rule.fn)
113110

114111
def at(
115-
self, ruleName: str, fn: RuleFunc, options: RuleOptionsType | None = None
112+
self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
116113
) -> None:
117114
"""Replace rule by name with new function & options.
118115
@@ -133,7 +130,7 @@ def before(
133130
self,
134131
beforeName: str,
135132
ruleName: str,
136-
fn: RuleFunc,
133+
fn: RuleFuncTv,
137134
options: RuleOptionsType | None = None,
138135
) -> None:
139136
"""Add new rule to chain before one with given name.
@@ -148,14 +145,16 @@ def before(
148145
options = options or {}
149146
if index == -1:
150147
raise KeyError(f"Parser rule not found: {beforeName}")
151-
self.__rules__.insert(index, Rule(ruleName, True, fn, options.get("alt", [])))
148+
self.__rules__.insert(
149+
index, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
150+
)
152151
self.__cache__ = None
153152

154153
def after(
155154
self,
156155
afterName: str,
157156
ruleName: str,
158-
fn: RuleFunc,
157+
fn: RuleFuncTv,
159158
options: RuleOptionsType | None = None,
160159
) -> None:
161160
"""Add new rule to chain after one with given name.
@@ -171,12 +170,12 @@ def after(
171170
if index == -1:
172171
raise KeyError(f"Parser rule not found: {afterName}")
173172
self.__rules__.insert(
174-
index + 1, Rule(ruleName, True, fn, options.get("alt", []))
173+
index + 1, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
175174
)
176175
self.__cache__ = None
177176

178177
def push(
179-
self, ruleName: str, fn: RuleFunc, options: RuleOptionsType | None = None
178+
self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
180179
) -> None:
181180
"""Push new rule to the end of chain.
182181
@@ -185,7 +184,9 @@ def push(
185184
:param options: new rule options (not mandatory).
186185
187186
"""
188-
self.__rules__.append(Rule(ruleName, True, fn, (options or {}).get("alt", [])))
187+
self.__rules__.append(
188+
Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
189+
)
189190
self.__cache__ = None
190191

191192
def enable(
@@ -252,7 +253,7 @@ def disable(
252253
self.__cache__ = None
253254
return result
254255

255-
def getRules(self, chainName: str) -> list[RuleFunc]:
256+
def getRules(self, chainName: str = "") -> list[RuleFuncTv]:
256257
"""Return array of active functions (rules) for given chain name.
257258
It analyzes rules configuration, compiles caches if not exists and returns result.
258259

markdown_it/rules_block/lheading.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# lheading (---, ==)
22
import logging
33

4-
from ..ruler import Ruler
54
from .state_block import StateBlock
65

76
LOGGER = logging.getLogger(__name__)
@@ -12,7 +11,7 @@ def lheading(state: StateBlock, startLine: int, endLine: int, silent: bool) -> b
1211

1312
level = None
1413
nextLine = startLine + 1
15-
ruler: Ruler = state.md.block.ruler
14+
ruler = state.md.block.ruler
1615
terminatorRules = ruler.getRules("paragraph")
1716

1817
if state.is_code_block(startLine):

markdown_it/rules_block/paragraph.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Paragraph."""
22
import logging
33

4-
from ..ruler import Ruler
54
from .state_block import StateBlock
65

76
LOGGER = logging.getLogger(__name__)
@@ -13,7 +12,7 @@ def paragraph(state: StateBlock, startLine: int, endLine: int, silent: bool) ->
1312
)
1413

1514
nextLine = startLine + 1
16-
ruler: Ruler = state.md.block.ruler
15+
ruler = state.md.block.ruler
1716
terminatorRules = ruler.getRules("paragraph")
1817
endLine = state.lineMax
1918

tests/test_api/test_plugin_creation.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,27 @@
66

77
def inline_rule(state, silent):
88
print("plugin called")
9+
return False
910

1011

1112
def test_inline_after(capsys):
12-
def _plugin(_md):
13+
def _plugin(_md: MarkdownIt) -> None:
1314
_md.inline.ruler.after("text", "new_rule", inline_rule)
1415

1516
MarkdownIt().use(_plugin).parse("[")
1617
assert "plugin called" in capsys.readouterr().out
1718

1819

1920
def test_inline_before(capsys):
20-
def _plugin(_md):
21+
def _plugin(_md: MarkdownIt) -> None:
2122
_md.inline.ruler.before("text", "new_rule", inline_rule)
2223

2324
MarkdownIt().use(_plugin).parse("a")
2425
assert "plugin called" in capsys.readouterr().out
2526

2627

2728
def test_inline_at(capsys):
28-
def _plugin(_md):
29+
def _plugin(_md: MarkdownIt) -> None:
2930
_md.inline.ruler.at("text", inline_rule)
3031

3132
MarkdownIt().use(_plugin).parse("a")
@@ -34,26 +35,27 @@ def _plugin(_md):
3435

3536
def block_rule(state, startLine, endLine, silent):
3637
print("plugin called")
38+
return False
3739

3840

3941
def test_block_after(capsys):
40-
def _plugin(_md):
42+
def _plugin(_md: MarkdownIt) -> None:
4143
_md.block.ruler.after("hr", "new_rule", block_rule)
4244

4345
MarkdownIt().use(_plugin).parse("a")
4446
assert "plugin called" in capsys.readouterr().out
4547

4648

4749
def test_block_before(capsys):
48-
def _plugin(_md):
50+
def _plugin(_md: MarkdownIt) -> None:
4951
_md.block.ruler.before("hr", "new_rule", block_rule)
5052

5153
MarkdownIt().use(_plugin).parse("a")
5254
assert "plugin called" in capsys.readouterr().out
5355

5456

5557
def test_block_at(capsys):
56-
def _plugin(_md):
58+
def _plugin(_md: MarkdownIt) -> None:
5759
_md.block.ruler.at("hr", block_rule)
5860

5961
MarkdownIt().use(_plugin).parse("a")
@@ -65,23 +67,23 @@ def core_rule(state):
6567

6668

6769
def test_core_after(capsys):
68-
def _plugin(_md):
70+
def _plugin(_md: MarkdownIt) -> None:
6971
_md.core.ruler.after("normalize", "new_rule", core_rule)
7072

7173
MarkdownIt().use(_plugin).parse("a")
7274
assert "plugin called" in capsys.readouterr().out
7375

7476

7577
def test_core_before(capsys):
76-
def _plugin(_md):
78+
def _plugin(_md: MarkdownIt) -> None:
7779
_md.core.ruler.before("normalize", "new_rule", core_rule)
7880

7981
MarkdownIt().use(_plugin).parse("a")
8082
assert "plugin called" in capsys.readouterr().out
8183

8284

8385
def test_core_at(capsys):
84-
def _plugin(_md):
86+
def _plugin(_md: MarkdownIt) -> None:
8587
_md.core.ruler.at("normalize", core_rule)
8688

8789
MarkdownIt().use(_plugin).parse("a")

0 commit comments

Comments
 (0)