 import torch
 from torch._dynamo.convert_frame import compile_lock
 from torch._inductor import config as inductor_config
+from torch._inductor import ir
 from torch._inductor.codegen.simd import SIMDKernelFeatures
 from torch._inductor.codegen.simd import constant_repr
 from torch._inductor.codegen.triton import TritonKernel
 from .ast_extension import expr_from_string
 from .ast_extension import statement_from_string
 from .compile_environment import CompileEnvironment
+from .node_masking import apply_masking
+from .node_masking import cached_masked_value
+from .node_masking import mask_node_inputs
 from .tile_strategy import TileStrategy

 if TYPE_CHECKING:
@@ -185,7 +189,9 @@ def convert_arg(arg: Node) -> TensorBox:
             )
         ),
     )
-    new_node.meta["lowering"] = lowering_cls(buffer, used_input_names)
+    new_node.meta["lowering"] = lowering = lowering_cls(buffer, used_input_names)
+    if isinstance(lowering, ReductionLowering):
+        lowering.add_input_mask(new_node)
     nodes.append(new_node)
     extra_input_names.append(buffer.get_name())

@@ -269,6 +275,10 @@ class Lowering:
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         raise NotImplementedError

+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        """Get the masked value for this node."""
+        return None
+

 @dataclasses.dataclass
 class InductorLowering(Lowering):
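Note: the new `get_masked_value` hook lets later lowerings ask what value occupies the out-of-bounds lanes of a padded tile. As a rough illustration of the idea in plain PyTorch (not the Triton codegen this file emits; `masked_load` below is a hypothetical stand-in), a masked load fills inactive lanes with a known `other` value, and that fill value is what downstream ops must account for:

    import torch

    # Hypothetical stand-in for a masked tile load: lanes past `n` are filled with `other`.
    def masked_load(data: torch.Tensor, n: int, block: int, other: float) -> torch.Tensor:
        idx = torch.arange(block)
        mask = idx < n
        padded = torch.full((block,), other)
        padded[mask] = data[:n]
        return padded

    tile = masked_load(torch.arange(5.0), n=5, block=8, other=0.0)
    print(tile)  # tensor([0., 1., 2., 3., 4., 0., 0., 0.]) -- masked lanes hold 0.0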
@@ -361,6 +371,11 @@ def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         output_name = _unpack_opsvalue(self.buffer.data.inner_fn(indices))
         return expr_from_string(output_name)

+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        """Get the masked value for this node."""
+        # TODO(jansel): use valueranges to determine masked value
+        return None
+

 @dataclasses.dataclass
 class ReductionLowering(InductorLowering):
@@ -383,6 +398,25 @@ def __init__(
         assert block_index is not None
         self.block_index: int = block_index

+    @property
+    def reduction_type(self) -> str:
+        reduction = self.buffer.data
+        assert isinstance(reduction, Reduction)
+        return reduction.reduction_type
+
+    def add_input_mask(self, node: torch.fx.Node) -> None:
+        """Modify the node to apply masking for the reduction if needed."""
+        reduction_type = self.reduction_type
+        input_dtype = None
+        for inp in node.all_input_nodes:
+            if isinstance(inp.meta["val"], torch.Tensor):
+                input_dtype = inp.meta["val"].dtype
+                break
+        assert input_dtype is not None
+        default = ir.Reduction.default_accumulator(reduction_type, input_dtype)
+        assert isinstance(default, (float, int, bool))
+        mask_node_inputs(node, default)
+
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         reduction = self.buffer.data
         assert isinstance(reduction, Reduction)
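Note: `add_input_mask` masks reduction inputs with the reduction's default accumulator (queried from Inductor's `ir.Reduction.default_accumulator`), which works because that value is the identity element of the reduction: padding a tile with it cannot change the reduced result. A small self-contained check of that property in plain PyTorch (the `pad_with` helper is illustrative only, not part of this file):

    import torch

    # Illustrative only: pad a row out to a block size with the reduction's identity value.
    def pad_with(x: torch.Tensor, block: int, value: float) -> torch.Tensor:
        return torch.cat([x, x.new_full((block - x.numel(),), value)])

    x = torch.tensor([3.0, 1.0, 4.0])
    assert pad_with(x, 8, 0.0).sum() == x.sum()            # sum identity is 0
    assert pad_with(x, 8, float("-inf")).max() == x.max()  # max identity is -inf
    assert pad_with(x, 8, 1.0).prod() == x.prod()          # prod identity is 1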
@@ -463,6 +497,11 @@ def normalize_args_kwargs(
         node.args = (*bound.arguments.values(),)
         node.kwargs = {}

+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        if self.api_func._get_masked_value is not None:
+            return self.api_func._get_masked_value(node)
+        return None
+

 @dataclasses.dataclass
 class SympyExprLowering(Lowering):
@@ -471,31 +510,61 @@ class SympyExprLowering(Lowering):
     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         return expr_from_string(ctx.cg.device_function.user_sympy_expr(self.expr))

+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        if isinstance(self.expr, sympy.Integer):
+            return int(self.expr)
+        if isinstance(self.expr, sympy.Float):
+            return float(self.expr)
+        return None
+

 @dataclasses.dataclass
 class LambdaLowering(Lowering):
     fn: Callable[..., object]
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None

     def codegen(self, ctx: GraphInterpreter, node: torch.fx.Node) -> object:
         return self.fn(ctx, node)

+    def get_masked_value(self, node: torch.fx.Node) -> float | bool | None:
+        if self.masked_value_fn is not None:
+            return self.masked_value_fn(node)
+        return None
+
+
+def passthrough_masked_value(node: torch.fx.Node) -> float | bool | None:
+    for input_node in node.all_input_nodes:
+        if isinstance(input_node.meta["val"], torch.Tensor):
+            return cached_masked_value(input_node)
+    return None
+

 aten_lowering_dispatch: dict[object, Callable[[torch.fx.Node], Lowering]] = {}


-def default_make_lowering(handler: CodegenHandler, node: torch.fx.Node) -> Lowering:
-    return LambdaLowering(handler)
+def default_make_lowering(
+    handler: CodegenHandler,
+    node: torch.fx.Node,
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
+) -> Lowering:
+    return LambdaLowering(handler, masked_value_fn=masked_value_fn)


 def register_lowering(
     fn: object,
     make_lowering: Callable[
         [CodegenHandler, torch.fx.Node], Lowering
     ] = default_make_lowering,
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
 ) -> Callable[[CodegenHandler], CodegenHandler]:
     def decorator(handler: CodegenHandler) -> CodegenHandler:
         assert fn not in aten_lowering_dispatch, f"Lowering for {fn} already registered"
-        aten_lowering_dispatch[fn] = lambda node: make_lowering(handler, node)
+        # pyre-ignore[28]
+        aten_lowering_dispatch[fn] = lambda node: make_lowering(
+            handler,
+            node,
+            masked_value_fn=masked_value_fn,
+        )
        return handler

     return decorator
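Note: `register_lowering` stores one factory per op in `aten_lowering_dispatch`, and after this change the factory also carries the op's `masked_value_fn`, so a node's fill value can later be recovered through its `Lowering`. A simplified, self-contained sketch of that registry pattern (the names below are illustrative, not the real helion types):

    from __future__ import annotations

    import dataclasses
    from typing import Callable

    @dataclasses.dataclass
    class ToyLowering:
        handler: Callable[..., object]
        masked_value_fn: Callable[[object], float | None] | None = None

        def get_masked_value(self, node: object) -> float | None:
            return self.masked_value_fn(node) if self.masked_value_fn else None

    dispatch: dict[object, Callable[[object], ToyLowering]] = {}

    def register(op: object, masked_value_fn=None):
        def decorator(handler):
            # Each op maps to a factory that builds a lowering carrying its masked_value_fn.
            dispatch[op] = lambda node: ToyLowering(handler, masked_value_fn)
            return handler
        return decorator

    @register("aten.full", masked_value_fn=lambda node: node["fill_value"])
    def lower_full(node):  # codegen handler body elided
        ...

    node = {"fill_value": 7.0}
    print(dispatch["aten.full"](node).get_masked_value(node))  # 7.0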
@@ -521,7 +590,12 @@ def codegen_getitem(ctx: GraphInterpreter, node: torch.fx.Node) -> object:


 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.full.default)
+@register_lowering(
+    torch.ops.aten.full.default,
+    masked_value_fn=lambda n: (
+        n.args[1] if isinstance(n.args[1], (int, float, bool)) else None
+    ),
+)
 def codegen_full(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     env = CompileEnvironment.current()
     size, fill_value = map_arg(node.args, lambda n: n.meta["val"])
@@ -539,7 +613,9 @@ def codegen_full(ctx: GraphInterpreter, node: torch.fx.Node) -> object:


 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.unsqueeze.default)
+@register_lowering(
+    torch.ops.aten.unsqueeze.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_unsqueeze(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     tensor, dim = map_arg(node.args, lambda arg: ctx.env[arg])
@@ -557,10 +633,14 @@ def codegen_unsqueeze(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )


-@register_lowering(torch.ops.aten.squeeze.dim)
-@register_lowering(torch.ops.aten.view.default)
+@register_lowering(torch.ops.aten.squeeze.dim, masked_value_fn=passthrough_masked_value)
+@register_lowering(
+    torch.ops.aten.view.default, masked_value_fn=passthrough_masked_value
+)
 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.reshape.default)
+@register_lowering(
+    torch.ops.aten.reshape.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_view(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "view kwargs not supported"
     tensor = map_arg(node.args[0], lambda arg: ctx.env[arg])
@@ -572,7 +652,9 @@ def codegen_view(ctx: GraphInterpreter, node: torch.fx.Node) -> object:


 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.permute.default)
+@register_lowering(
+    torch.ops.aten.permute.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     tensor, dims = map_arg(node.args, lambda arg: ctx.env[arg])
@@ -586,7 +668,9 @@ def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:


 # pyre-fixme[56]
-@register_lowering(torch.ops.aten.expand.default)
+@register_lowering(
+    torch.ops.aten.expand.default, masked_value_fn=passthrough_masked_value
+)
 def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     assert not node.kwargs, "getitem kwargs not supported"
     tensor, _ = map_arg(node.args, lambda arg: ctx.env[arg])
@@ -606,7 +690,11 @@ def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )


-def apply_dot_requirements(handler: CodegenHandler, node: torch.fx.Node) -> Lowering:
+def apply_dot_requirements(
+    handler: CodegenHandler,
+    node: torch.fx.Node,
+    masked_value_fn: Callable[[torch.fx.Node], float | bool | None] | None = None,
+) -> Lowering:
     """Apply min_dot_size requirements to the config_spec"""
     assert not node.kwargs, "dot kwargs not supported"
     assert len(node.args) in (2, 3)
@@ -625,7 +713,14 @@ def apply_dot_requirements(handler: CodegenHandler, node: torch.fx.Node) -> Lowering:
         block_idx = TileStrategy.get_block_index(shape)
         if block_idx is not None:
             env.block_sizes[block_idx].update_min_block(min_size, allow_flattened=True)
-    return LambdaLowering(handler)
+    # inputs to the dot operation must be zero-masked
+    *maybe_acc, lnode, rnode = node.args
+    assert isinstance(lnode, torch.fx.Node)
+    assert isinstance(rnode, torch.fx.Node)
+    lnode = apply_masking(lnode, base_node=node, other=0)
+    rnode = apply_masking(rnode, base_node=node, other=0)
+    node.args = (*maybe_acc, lnode, rnode)
+    return LambdaLowering(handler, masked_value_fn=masked_value_fn)


 @register_lowering(torch.ops.aten.bmm.default, apply_dot_requirements)
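Note: the zero-masking requirement for dot inputs follows from how matmul accumulates: every K-lane of a tile contributes a product term, so padded lanes must hold 0 rather than whatever fill value the load happened to use. A quick standalone demonstration in plain PyTorch (the shapes and block size below are made up for illustration; this is not the generated Triton code):

    import torch

    torch.manual_seed(0)
    a = torch.randn(4, 3)   # actual K = 3
    b = torch.randn(3, 4)

    # Pad K up to a block of 8: zero padding keeps the matmul exact,
    # while garbage padding corrupts every output element.
    a_zero = torch.nn.functional.pad(a, (0, 5))        # (4, 8), zeros in lanes 3..7
    b_zero = torch.nn.functional.pad(b, (0, 0, 0, 5))  # (8, 4)
    a_junk = a_zero.clone()
    a_junk[:, 3:] = 1.0
    b_junk = b_zero.clone()
    b_junk[3:, :] = 1.0

    assert torch.allclose(a_zero @ b_zero, a @ b)
    assert not torch.allclose(a_junk @ b_junk, a @ b)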