Skip to content

Commit 247d8d4

Browse files
committed
[mlir][gpu] Add uniform flag to gpu reduction ops
Differential Revision: https://reviews.llvm.org/D138758
1 parent 55cbda9 commit 247d8d4

File tree

7 files changed

+44
-21
lines changed

7 files changed

+44
-21
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,8 @@ def GPU_AllReduceOperationAttr : EnumAttr<GPU_Dialect, GPU_AllReduceOperation,
688688
def GPU_AllReduceOp : GPU_Op<"all_reduce",
689689
[SameOperandsAndResultType, IsolatedFromAbove]>,
690690
Arguments<(ins AnyType:$value,
691-
OptionalAttr<GPU_AllReduceOperationAttr>:$op)>,
691+
OptionalAttr<GPU_AllReduceOperationAttr>:$op,
692+
UnitAttr:$uniform)>,
692693
Results<(outs AnyType)> {
693694
let summary = "Reduce values among workgroup.";
694695
let description = [{
@@ -711,19 +712,21 @@ def GPU_AllReduceOp : GPU_Op<"all_reduce",
711712
accumulation as code region. The accumulation operation must be one of:
712713
`add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
713714

714-
Either none or all work items of a workgroup need to execute this op
715-
in convergence.
715+
If the `uniform` flag is set, either none or all work items of a workgroup
716+
need to execute this op in convergence.
716717
}];
717718
let regions = (region AnyRegion:$body);
718-
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value $body attr-dict
719+
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
720+
(`uniform` $uniform^)? $body attr-dict
719721
`:` functional-type(operands, results) }];
720722
let hasRegionVerifier = 1;
721723
}
722724

723725
def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
724726
[SameOperandsAndResultType]>,
725727
Arguments<(ins AnyType:$value,
726-
GPU_AllReduceOperationAttr:$op)>,
728+
GPU_AllReduceOperationAttr:$op,
729+
UnitAttr:$uniform)>,
727730
Results<(outs AnyType)> {
728731
let summary = "Reduce values among subgroup.";
729732
let description = [{
@@ -736,10 +739,11 @@ def GPU_SubgroupReduceOp : GPU_Op<"subgroup_reduce",
736739
%1 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
737740
```
738741

739-
Either none or all work items of a subgroup need to execute this op
740-
in convergence.
742+
If the `uniform` flag is set, either none or all work items of a subgroup
743+
need to execute this op in convergence.
741744
}];
742-
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value attr-dict
745+
let assemblyFormat = [{ custom<AllReduceOperation>($op) $value
746+
(`uniform` $uniform^)? attr-dict
743747
`:` functional-type(operands, results) }];
744748
let hasVerifier = 1;
745749
}

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -394,14 +394,23 @@ struct GpuAllReduceConversion : public RewritePattern {
394394
LogicalResult matchAndRewrite(Operation *op,
395395
PatternRewriter &rewriter) const override {
396396
auto funcOp = cast<gpu::GPUFuncOp>(op);
397-
auto callback = [&](gpu::AllReduceOp reduceOp) {
398-
GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
399-
// Performing a rewrite invalidates the walk iterator. Report interrupt
400-
// so that we can start a new walk until all all_reduce ops are replaced.
401-
return WalkResult::interrupt();
397+
398+
SmallVector<gpu::AllReduceOp> reduceOps;
399+
auto callback = [&](gpu::AllReduceOp reduceOp) -> WalkResult {
400+
if (!reduceOp.getUniform())
401+
return WalkResult::interrupt();
402+
403+
reduceOps.emplace_back(reduceOp);
404+
return WalkResult::advance();
402405
};
403-
while (funcOp.walk(callback).wasInterrupted()) {
404-
}
406+
407+
if (funcOp.walk(callback).wasInterrupted())
408+
return rewriter.notifyMatchFailure(
409+
op, "Non uniform reductions are not supported yet.");
410+
411+
for (gpu::AllReduceOp reduceOp : reduceOps)
412+
GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
413+
405414
return success();
406415
}
407416
};

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ gpu.module @test_module {
8989
// CHECK: nvvm.shfl.sync bfly {{.*}}
9090
// CHECK: nvvm.barrier0
9191
// CHECK: llvm.fadd
92-
%result = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
92+
%result = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
9393

9494
gpu.return
9595
}
@@ -104,7 +104,7 @@ gpu.module @test_module {
104104
// TODO: Check full IR expansion once lowering has settled.
105105
// CHECK: nvvm.shfl.sync bfly {{.*}}
106106
// CHECK: nvvm.barrier0
107-
%result = gpu.all_reduce %arg0 {
107+
%result = gpu.all_reduce %arg0 uniform {
108108
^bb(%lhs : i32, %rhs : i32):
109109
%xor = arith.xori %lhs, %rhs : i32
110110
"gpu.yield"(%xor) : (i32) -> ()

mlir/test/Dialect/GPU/all-reduce-max.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ gpu.module @kernels {
195195
// CHECK: cf.br ^bb42
196196
// CHECK: ^bb42:
197197
// CHECK: gpu.barrier
198-
%sum = gpu.all_reduce max %arg0 {} : (f32) -> (f32)
198+
%sum = gpu.all_reduce max %arg0 uniform {} : (f32) -> (f32)
199199
gpu.return
200200
}
201201

mlir/test/Dialect/GPU/all-reduce.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ gpu.module @kernels {
175175
// CHECK: cf.br ^bb42
176176
// CHECK: ^bb42:
177177
// CHECK: gpu.barrier
178-
%sum = gpu.all_reduce add %arg0 {} : (f32) -> (f32)
178+
%sum = gpu.all_reduce add %arg0 uniform {} : (f32) -> (f32)
179179
gpu.return
180180
}
181181

mlir/test/Dialect/GPU/multiple-all-reduce.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ func.func @main() {
1010
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
1111
threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1, %block_z = %c1) {
1212
%val = memref.load %data[%bx, %tx] : memref<2x6xf32>
13-
%reduced0 = gpu.all_reduce add %val {} : (f32) -> (f32)
13+
%reduced0 = gpu.all_reduce add %val uniform {} : (f32) -> (f32)
1414
memref.store %reduced0, %sum[%bx] : memref<2xf32>
15-
%reduced1 = gpu.all_reduce mul %val {} : (f32) -> (f32)
15+
%reduced1 = gpu.all_reduce mul %val uniform {} : (f32) -> (f32)
1616
memref.store %reduced1, %mul[%bx] : memref<2xf32>
1717
gpu.terminator
1818
}

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,21 @@ module attributes {gpu.container_module} {
8383
%SgSi = gpu.subgroup_size : index
8484

8585
%one = arith.constant 1.0 : f32
86+
87+
// CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} {
88+
// CHECK-NEXT: } : (f32) -> f32
8689
%sum = gpu.all_reduce add %one {} : (f32) -> (f32)
8790

91+
// CHECK: %{{.*}} = gpu.all_reduce add %{{.*}} uniform {
92+
// CHECK-NEXT: } : (f32) -> f32
93+
%sum1 = gpu.all_reduce add %one uniform {} : (f32) -> f32
94+
8895
// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} : (f32) -> f32
8996
%sum_subgroup = gpu.subgroup_reduce add %one : (f32) -> f32
9097

98+
// CHECK: %{{.*}} = gpu.subgroup_reduce add %{{.*}} uniform : (f32) -> f32
99+
%sum_subgroup1 = gpu.subgroup_reduce add %one uniform : (f32) -> f32
100+
91101
%width = arith.constant 7 : i32
92102
%offset = arith.constant 3 : i32
93103
// CHECK: gpu.shuffle xor %{{.*}}, %{{.*}}, %{{.*}} : f32

0 commit comments

Comments
 (0)