Skip to content

Commit b13ed82

Browse files
committed
[NVPTX] expand associativity to fmax / fmin and add comments
1 parent 1a01686 commit b13ed82

File tree

2 files changed

+133
-109
lines changed

2 files changed

+133
-109
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,19 +2225,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
22252225
}
22262226

22272227
/// A generic routine for constructing a tree reduction on a vector operand.
2228-
/// This method differs from iterative splitting in DAGTypeLegalizer by
2229-
/// progressively grouping elements bottom-up.
2228+
/// This method groups elements bottom-up, progressively building each level.
2229+
/// This approach differs from top-down iterative splitting used in
2230+
/// DAGTypeLegalizer and ExpandReductions.
2231+
///
2232+
/// Also, the flags on the original reduction operation will be propagated to
2233+
/// each scalar operation.
22302234
static SDValue BuildTreeReduction(
22312235
const SmallVector<SDValue> &Elements, EVT EltTy,
22322236
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
22332237
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2234-
// now build the computation graph in place at each level
2238+
// Build the reduction tree at each level, starting with all the elements.
22352239
SmallVector<SDValue> Level = Elements;
2240+
22362241
unsigned OpIdx = 0;
22372242
while (Level.size() > 1) {
2243+
// Try to reduce this level using the current operator.
22382244
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
22392245

2240-
// partially reduce all elements in level
2246+
// Build the next level by partially reducing all elements.
22412247
SmallVector<SDValue> ReducedLevel;
22422248
unsigned I = 0, E = Level.size();
22432249
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2248,18 +2254,23 @@ static SDValue BuildTreeReduction(
22482254
}
22492255

22502256
if (I < E) {
2257+
// We have leftover elements. Why?
2258+
22512259
if (ReducedLevel.empty()) {
2252-
// The current operator requires more inputs than there are operands at
2253-
// this level. Pick a smaller operator and retry.
2260+
// ...because this level is now so small that the current operator is
2261+
// too big for it. Pick a smaller operator and retry.
22542262
++OpIdx;
22552263
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
22562264
continue;
22572265
}
22582266

2259-
// Otherwise, we just have a remainder, which we push to the next level.
2267+
// ...because the operator's required number of inputs doesn't divide
2268+
// evenly this level. We push this remainder to the next level.
22602269
for (; I < E; ++I)
22612270
ReducedLevel.push_back(Level[I]);
22622271
}
2272+
2273+
// Process the next level.
22632274
Level = ReducedLevel;
22642275
}
22652276

@@ -2275,6 +2286,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22752286
const SDNodeFlags Flags = Op->getFlags();
22762287
SDValue Vector;
22772288
SDValue Accumulator;
2289+
22782290
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
22792291
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
22802292
// special case with accumulator as first arg
@@ -2284,85 +2296,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22842296
// default case
22852297
Vector = Op.getOperand(0);
22862298
}
2299+
22872300
EVT EltTy = Vector.getValueType().getVectorElementType();
22882301
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22892302
STI.getPTXVersion() >= 88;
22902303

22912304
// A list of SDNode opcodes with equivalent semantics, sorted descending by
22922305
// number of inputs they take.
22932306
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2294-
bool IsReassociatable;
2307+
2308+
// Whether we can lower to scalar operations in an arbitrary order.
2309+
bool IsAssociative;
22952310

22962311
switch (Op->getOpcode()) {
22972312
case ISD::VECREDUCE_FADD:
22982313
case ISD::VECREDUCE_SEQ_FADD:
22992314
ScalarOps = {{ISD::FADD, 2}};
2300-
IsReassociatable = false;
2315+
IsAssociative = Op->getOpcode() == ISD::VECREDUCE_FADD;
23012316
break;
23022317
case ISD::VECREDUCE_FMUL:
23032318
case ISD::VECREDUCE_SEQ_FMUL:
23042319
ScalarOps = {{ISD::FMUL, 2}};
2305-
IsReassociatable = false;
2320+
IsAssociative = Op->getOpcode() == ISD::VECREDUCE_FMUL;
23062321
break;
23072322
case ISD::VECREDUCE_FMAX:
23082323
if (CanUseMinMax3)
23092324
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
23102325
ScalarOps.push_back({ISD::FMAXNUM, 2});
2311-
IsReassociatable = false;
2326+
// Definition of maxNum in IEEE 754 2008 is non-associative, but only
2327+
// because of how sNaNs are treated. However, NVIDIA GPUs don't support
2328+
// sNaNs.
2329+
IsAssociative = true;
23122330
break;
23132331
case ISD::VECREDUCE_FMIN:
23142332
if (CanUseMinMax3)
23152333
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
23162334
ScalarOps.push_back({ISD::FMINNUM, 2});
2317-
IsReassociatable = false;
2335+
// Definition of minNum in IEEE 754 2008 is non-associative, but only
2336+
// because of how sNaNs are treated. However, NVIDIA GPUs don't support
2337+
// sNaNs.
2338+
IsAssociative = true;
23182339
break;
23192340
case ISD::VECREDUCE_FMAXIMUM:
23202341
if (CanUseMinMax3)
23212342
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
23222343
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2323-
IsReassociatable = false;
2344+
IsAssociative = true;
23242345
break;
23252346
case ISD::VECREDUCE_FMINIMUM:
23262347
if (CanUseMinMax3)
23272348
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
23282349
ScalarOps.push_back({ISD::FMINIMUM, 2});
2329-
IsReassociatable = false;
2350+
IsAssociative = true;
23302351
break;
23312352
case ISD::VECREDUCE_ADD:
23322353
ScalarOps = {{ISD::ADD, 2}};
2333-
IsReassociatable = true;
2354+
IsAssociative = true;
23342355
break;
23352356
case ISD::VECREDUCE_MUL:
23362357
ScalarOps = {{ISD::MUL, 2}};
2337-
IsReassociatable = true;
2358+
IsAssociative = true;
23382359
break;
23392360
case ISD::VECREDUCE_UMAX:
23402361
ScalarOps = {{ISD::UMAX, 2}};
2341-
IsReassociatable = true;
2362+
IsAssociative = true;
23422363
break;
23432364
case ISD::VECREDUCE_UMIN:
23442365
ScalarOps = {{ISD::UMIN, 2}};
2345-
IsReassociatable = true;
2366+
IsAssociative = true;
23462367
break;
23472368
case ISD::VECREDUCE_SMAX:
23482369
ScalarOps = {{ISD::SMAX, 2}};
2349-
IsReassociatable = true;
2370+
IsAssociative = true;
23502371
break;
23512372
case ISD::VECREDUCE_SMIN:
23522373
ScalarOps = {{ISD::SMIN, 2}};
2353-
IsReassociatable = true;
2374+
IsAssociative = true;
23542375
break;
23552376
case ISD::VECREDUCE_AND:
23562377
ScalarOps = {{ISD::AND, 2}};
2357-
IsReassociatable = true;
2378+
IsAssociative = true;
23582379
break;
23592380
case ISD::VECREDUCE_OR:
23602381
ScalarOps = {{ISD::OR, 2}};
2361-
IsReassociatable = true;
2382+
IsAssociative = true;
23622383
break;
23632384
case ISD::VECREDUCE_XOR:
23642385
ScalarOps = {{ISD::XOR, 2}};
2365-
IsReassociatable = true;
2386+
IsAssociative = true;
23662387
break;
23672388
default:
23682389
llvm_unreachable("unhandled vecreduce operation");
@@ -2379,18 +2400,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23792400
}
23802401

23812402
// Lower to tree reduction.
2382-
if (IsReassociatable || Flags.hasAllowReassociation()) {
2383-
// we don't expect an accumulator for reassociatable vector reduction ops
2403+
if (IsAssociative || allowUnsafeFPMath(DAG.getMachineFunction())) {
2404+
// we don't expect an accumulator for reassociative vector reduction ops
23842405
assert(!Accumulator && "unexpected accumulator");
23852406
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
23862407
}
23872408

23882409
// Lower to sequential reduction.
23892410
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2411+
// Try to reduce the remaining sequence as much as possible using the
2412+
// current operator.
23902413
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
23912414
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
23922415

23932416
if (!Accumulator) {
2417+
// Try to initialize the accumulator using the current operator.
23942418
if (I + DefaultGroupSize <= NumElts) {
23952419
Accumulator = DAG.getNode(
23962420
DefaultScalarOp, DL, EltTy,

0 commit comments

Comments
 (0)