Skip to content

Commit 4e481f8

Browse files
committed
[NVPTX] support VECREDUCE_SEQ ops and remove option
1 parent aed0f4a commit 4e481f8

File tree

2 files changed

+44
-36
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -85,12 +85,6 @@ static cl::opt<unsigned> FMAContractLevelOpt(
8585
" 1: do it 2: do it aggressively"),
8686
cl::init(2));
8787

88-
static cl::opt<bool> DisableFOpTreeReduce(
89-
"nvptx-disable-fop-tree-reduce", cl::Hidden,
90-
cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
91-
"reduction operations"),
92-
cl::init(false));
93-
9488
static cl::opt<int> UsePrecDivF32(
9589
"nvptx-prec-divf32", cl::Hidden,
9690
cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -841,6 +835,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
841835
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
842836
EltVT == MVT::f64) {
843837
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
838+
ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
844839
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845840
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846841
VT, Custom);
@@ -2204,12 +2199,19 @@ static SDValue BuildTreeReduction(
22042199
/// max3/min3 when the target supports them.
22052200
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22062201
SelectionDAG &DAG) const {
2207-
if (DisableFOpTreeReduce)
2208-
return SDValue();
2209-
22102202
SDLoc DL(Op);
22112203
const SDNodeFlags Flags = Op->getFlags();
2212-
const SDValue &Vector = Op.getOperand(0);
2204+
SDValue Vector;
2205+
SDValue Accumulator;
2206+
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
2207+
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
2208+
// special case with accumulator as first arg
2209+
Accumulator = Op.getOperand(0);
2210+
Vector = Op.getOperand(1);
2211+
} else {
2212+
// default case
2213+
Vector = Op.getOperand(0);
2214+
}
22132215
EVT EltTy = Vector.getValueType().getVectorElementType();
22142216
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22152217
STI.getPTXVersion() >= 88;
@@ -2221,10 +2223,12 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22212223

22222224
switch (Op->getOpcode()) {
22232225
case ISD::VECREDUCE_FADD:
2226+
case ISD::VECREDUCE_SEQ_FADD:
22242227
ScalarOps = {{ISD::FADD, 2}};
22252228
IsReassociatable = false;
22262229
break;
22272230
case ISD::VECREDUCE_FMUL:
2231+
case ISD::VECREDUCE_SEQ_FMUL:
22282232
ScalarOps = {{ISD::FMUL, 2}};
22292233
IsReassociatable = false;
22302234
break;
@@ -2303,11 +2307,13 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23032307
}
23042308

23052309
// Lower to tree reduction.
2306-
if (IsReassociatable || Flags.hasAllowReassociation())
2310+
if (IsReassociatable || Flags.hasAllowReassociation()) {
2311+
// we don't expect an accumulator for reassociatable vector reduction ops
2312+
assert(!Accumulator && "unexpected accumulator");
23072313
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2314+
}
23082315

23092316
// Lower to sequential reduction.
2310-
SDValue Accumulator;
23112317
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
23122318
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
23132319
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
@@ -3113,6 +3119,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
31133119
return LowerCONCAT_VECTORS(Op, DAG);
31143120
case ISD::VECREDUCE_FADD:
31153121
case ISD::VECREDUCE_FMUL:
3122+
case ISD::VECREDUCE_SEQ_FADD:
3123+
case ISD::VECREDUCE_SEQ_FMUL:
31163124
case ISD::VECREDUCE_FMAX:
31173125
case ISD::VECREDUCE_FMIN:
31183126
case ISD::VECREDUCE_FMAXIMUM:

llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll

Lines changed: 24 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -23,19 +23,19 @@ define half @reduce_fadd_half(<8 x half> %in) {
2323
; CHECK-EMPTY:
2424
; CHECK-NEXT: // %bb.0:
2525
; CHECK-NEXT: ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
26-
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
27-
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r3;
28-
; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r2;
29-
; CHECK-NEXT: mov.b32 {%rs7, %rs8}, %r1;
30-
; CHECK-NEXT: mov.b16 %rs9, 0x0000;
31-
; CHECK-NEXT: add.rn.f16 %rs10, %rs7, %rs9;
32-
; CHECK-NEXT: add.rn.f16 %rs11, %rs10, %rs8;
33-
; CHECK-NEXT: add.rn.f16 %rs12, %rs11, %rs5;
34-
; CHECK-NEXT: add.rn.f16 %rs13, %rs12, %rs6;
35-
; CHECK-NEXT: add.rn.f16 %rs14, %rs13, %rs3;
36-
; CHECK-NEXT: add.rn.f16 %rs15, %rs14, %rs4;
37-
; CHECK-NEXT: add.rn.f16 %rs16, %rs15, %rs1;
38-
; CHECK-NEXT: add.rn.f16 %rs17, %rs16, %rs2;
26+
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r1;
27+
; CHECK-NEXT: mov.b16 %rs3, 0x0000;
28+
; CHECK-NEXT: add.rn.f16 %rs4, %rs1, %rs3;
29+
; CHECK-NEXT: add.rn.f16 %rs5, %rs4, %rs2;
30+
; CHECK-NEXT: mov.b32 {%rs6, %rs7}, %r2;
31+
; CHECK-NEXT: add.rn.f16 %rs8, %rs5, %rs6;
32+
; CHECK-NEXT: add.rn.f16 %rs9, %rs8, %rs7;
33+
; CHECK-NEXT: mov.b32 {%rs10, %rs11}, %r3;
34+
; CHECK-NEXT: add.rn.f16 %rs12, %rs9, %rs10;
35+
; CHECK-NEXT: add.rn.f16 %rs13, %rs12, %rs11;
36+
; CHECK-NEXT: mov.b32 {%rs14, %rs15}, %r4;
37+
; CHECK-NEXT: add.rn.f16 %rs16, %rs13, %rs14;
38+
; CHECK-NEXT: add.rn.f16 %rs17, %rs16, %rs15;
3939
; CHECK-NEXT: st.param.b16 [func_retval0], %rs17;
4040
; CHECK-NEXT: ret;
4141
%res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
@@ -174,17 +174,17 @@ define half @reduce_fmul_half(<8 x half> %in) {
174174
; CHECK-EMPTY:
175175
; CHECK-NEXT: // %bb.0:
176176
; CHECK-NEXT: ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_param_0];
177-
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
178-
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r3;
179-
; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r2;
180-
; CHECK-NEXT: mov.b32 {%rs7, %rs8}, %r1;
181-
; CHECK-NEXT: mul.rn.f16 %rs9, %rs7, %rs8;
182-
; CHECK-NEXT: mul.rn.f16 %rs10, %rs9, %rs5;
183-
; CHECK-NEXT: mul.rn.f16 %rs11, %rs10, %rs6;
184-
; CHECK-NEXT: mul.rn.f16 %rs12, %rs11, %rs3;
185-
; CHECK-NEXT: mul.rn.f16 %rs13, %rs12, %rs4;
186-
; CHECK-NEXT: mul.rn.f16 %rs14, %rs13, %rs1;
187-
; CHECK-NEXT: mul.rn.f16 %rs15, %rs14, %rs2;
177+
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r2;
178+
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r1;
179+
; CHECK-NEXT: mul.rn.f16 %rs5, %rs3, %rs4;
180+
; CHECK-NEXT: mul.rn.f16 %rs6, %rs5, %rs1;
181+
; CHECK-NEXT: mul.rn.f16 %rs7, %rs6, %rs2;
182+
; CHECK-NEXT: mov.b32 {%rs8, %rs9}, %r3;
183+
; CHECK-NEXT: mul.rn.f16 %rs10, %rs7, %rs8;
184+
; CHECK-NEXT: mul.rn.f16 %rs11, %rs10, %rs9;
185+
; CHECK-NEXT: mov.b32 {%rs12, %rs13}, %r4;
186+
; CHECK-NEXT: mul.rn.f16 %rs14, %rs11, %rs12;
187+
; CHECK-NEXT: mul.rn.f16 %rs15, %rs14, %rs13;
188188
; CHECK-NEXT: st.param.b16 [func_retval0], %rs15;
189189
; CHECK-NEXT: ret;
190190
%res = call half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %in)

0 commit comments

Comments (0)