@@ -41,37 +41,37 @@ def __init__(
41
41
) -> None :
42
42
super ().__init__ ()
43
43
self .strategies : list [TileStrategy ] = []
44
- self .block_indices_to_strategy : dict [tuple [int , ...], TileStrategy ] = {}
44
+ self .block_id_to_strategy : dict [tuple [int , ...], TileStrategy ] = {}
45
45
self ._add_loop_strategies (fn , config )
46
46
self ._add_reduction_strategies (fn , config )
47
47
48
48
def _add_loop_strategies (self , fn : DeviceFunction , config : Config ) -> None :
49
49
device_ir = HostFunction .current ().device_ir
50
- for block_indices in device_ir .grid_block_indices :
51
- self ._add_loop_strategy (block_indices , fn , config )
50
+ for block_ids in device_ir .grid_block_ids :
51
+ self ._add_loop_strategy (block_ids , fn , config )
52
52
for graph in device_ir .graphs :
53
53
if isinstance (graph , ForLoopGraphInfo ) and not isinstance (
54
54
graph , ReductionLoopGraphInfo
55
55
):
56
- block_indices = [* graph .block_indices ]
57
- self ._add_loop_strategy (block_indices , fn , config )
56
+ block_ids = [* graph .block_ids ]
57
+ self ._add_loop_strategy (block_ids , fn , config )
58
58
59
59
def _add_loop_strategy (
60
- self , block_indices : list [int ], fn : DeviceFunction , config : Config
60
+ self , block_ids : list [int ], fn : DeviceFunction , config : Config
61
61
) -> None :
62
62
env = CompileEnvironment .current ()
63
- block_size_infos = [env .block_sizes [i ] for i in block_indices ]
63
+ block_size_infos = [env .block_sizes [i ] for i in block_ids ]
64
64
loop_order = env .config_spec .loop_orders .config_get (
65
- config .loop_orders , block_indices [0 ]
66
- ) or [* range (len (block_indices ))]
65
+ config .loop_orders , block_ids [0 ]
66
+ ) or [* range (len (block_ids ))]
67
67
l2_grouping = env .config_spec .l2_groupings .config_get (
68
- config .l2_groupings , block_indices [0 ], 1
68
+ config .l2_groupings , block_ids [0 ], 1
69
69
)
70
70
71
71
if block_size_infos [0 ].is_grid ():
72
72
strategy : TileStrategy = NDGridTileStrategy (
73
73
fn ,
74
- block_indices ,
74
+ block_ids ,
75
75
loop_order = loop_order ,
76
76
)
77
77
elif block_size_infos [0 ].is_flattened (config ):
@@ -80,20 +80,20 @@ def _add_loop_strategy(
80
80
)
81
81
strategy : TileStrategy = FlattenedTileStrategy (
82
82
fn ,
83
- block_indices ,
83
+ block_ids ,
84
84
block_size = block_size ,
85
85
loop_order = loop_order ,
86
86
)
87
87
else :
88
88
strategy = NDTileStrategy (
89
89
fn ,
90
- block_indices ,
90
+ block_ids ,
91
91
block_size = [bs .from_config_assert (config ) for bs in block_size_infos ],
92
92
loop_order = loop_order ,
93
93
l2_grouping = l2_grouping ,
94
94
)
95
95
self .strategies .append (strategy )
96
- self .block_indices_to_strategy [tuple (block_indices )] = strategy
96
+ self .block_id_to_strategy [tuple (block_ids )] = strategy
97
97
98
98
def _add_reduction_strategies (self , fn : DeviceFunction , config : Config ) -> None :
99
99
env = CompileEnvironment .current ()
@@ -107,20 +107,20 @@ def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
107
107
else :
108
108
strategy = LoopedReductionStrategy (fn , block_id , reduction_loop )
109
109
self .strategies .append (strategy )
110
- self .block_indices_to_strategy [(block_id ,)] = strategy
110
+ self .block_id_to_strategy [(block_id ,)] = strategy
111
111
112
- def codegen_grid (self , state : CodegenState , block_indices : list [int ]) -> None :
113
- strategy = self .block_indices_to_strategy [tuple (block_indices )]
112
+ def codegen_grid (self , state : CodegenState , block_ids : list [int ]) -> None :
113
+ strategy = self .block_id_to_strategy [tuple (block_ids )]
114
114
strategy .codegen_grid (state )
115
115
for other_strategy in self .strategies :
116
116
if other_strategy is not strategy :
117
117
other_strategy .codegen_preamble (state )
118
118
state .codegen .set_active_loops (DeviceGridState (strategy ))
119
119
120
120
def codegen_device_loop (
121
- self , state : CodegenState , block_indices : list [int ]
121
+ self , state : CodegenState , block_ids : list [int ]
122
122
) -> DeviceLoopState :
123
- strategy = self .block_indices_to_strategy [tuple (block_indices )]
123
+ strategy = self .block_id_to_strategy [tuple (block_ids )]
124
124
return strategy .codegen_device_loop (state )
125
125
126
126
def _compact_shape (self , shapes : ShapeLike ) -> list [CompactedShape ]:
@@ -161,14 +161,14 @@ def expand_str(self, shape: ShapeLike, i: int) -> str:
161
161
return f"[{ ', ' .join (result )} ]"
162
162
163
163
def get_reduction_strategy (self , block_idx : int ) -> ReductionStrategy :
164
- strategy = self .block_indices_to_strategy [(block_idx ,)]
164
+ strategy = self .block_id_to_strategy [(block_idx ,)]
165
165
assert isinstance (strategy , ReductionStrategy )
166
166
return strategy
167
167
168
168
def user_size (self , block_index : int ) -> sympy .Expr :
169
169
"""The user-visible size of the block index."""
170
170
# This only does something special for reduction loops, only need to check for 1D loop
171
- strategy = self .block_indices_to_strategy .get ((block_index ,))
171
+ strategy = self .block_id_to_strategy .get ((block_index ,))
172
172
if strategy is None :
173
173
return CompileEnvironment .current ().block_sizes [block_index ].symbol ()
174
174
return strategy .user_size (block_index )
0 commit comments