@@ -2225,19 +2225,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2225
2225
}
2226
2226
2227
2227
/// A generic routine for constructing a tree reduction on a vector operand.
2228
- /// This method differs from iterative splitting in DAGTypeLegalizer by
2229
- /// progressively grouping elements bottom-up.
2228
+ /// This method groups elements bottom-up, progressively building each level.
2229
+ /// This approach differs from top-down iterative splitting used in
2230
+ /// DAGTypeLegalizer and ExpandReductions.
2231
+ ///
2232
+ /// Also, the flags on the original reduction operation will be propagated to
2233
+ /// each scalar operation.
2230
2234
static SDValue BuildTreeReduction (
2231
2235
const SmallVector<SDValue> &Elements, EVT EltTy,
2232
2236
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2233
2237
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2234
- // now build the computation graph in place at each level
2238
+ // Build the reduction tree at each level, starting with all the elements.
2235
2239
SmallVector<SDValue> Level = Elements;
2240
+
2236
2241
unsigned OpIdx = 0 ;
2237
2242
while (Level.size () > 1 ) {
2243
+ // Try to reduce this level using the current operator.
2238
2244
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2239
2245
2240
- // partially reduce all elements in level
2246
+ // Build the next level by partially reducing all elements.
2241
2247
SmallVector<SDValue> ReducedLevel;
2242
2248
unsigned I = 0 , E = Level.size ();
2243
2249
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2248,18 +2254,23 @@ static SDValue BuildTreeReduction(
2248
2254
}
2249
2255
2250
2256
if (I < E) {
2257
+ // We have leftover elements. Why?
2258
+
2251
2259
if (ReducedLevel.empty ()) {
2252
- // The current operator requires more inputs than there are operands at
2253
- // this level. Pick a smaller operator and retry.
2260
+ // ...because this level is now so small that the current operator is
2261
+ // too big for it. Pick a smaller operator and retry.
2254
2262
++OpIdx;
2255
2263
assert (OpIdx < Ops.size () && "no smaller operators for reduction");
2256
2264
continue ;
2257
2265
}
2258
2266
2259
- // Otherwise, we just have a remainder, which we push to the next level.
2267
+ // ...because the operator's required number of inputs doesn't evenly
2268
+ // divide this level. We push this remainder to the next level.
2260
2269
for (; I < E; ++I)
2261
2270
ReducedLevel.push_back (Level[I]);
2262
2271
}
2272
+
2273
+ // Process the next level.
2263
2274
Level = ReducedLevel;
2264
2275
}
2265
2276
@@ -2275,6 +2286,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2275
2286
const SDNodeFlags Flags = Op->getFlags ();
2276
2287
SDValue Vector;
2277
2288
SDValue Accumulator;
2289
+
2278
2290
if (Op->getOpcode () == ISD::VECREDUCE_SEQ_FADD ||
2279
2291
Op->getOpcode () == ISD::VECREDUCE_SEQ_FMUL) {
2280
2292
// special case with accumulator as first arg
@@ -2284,85 +2296,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2284
2296
// default case
2285
2297
Vector = Op.getOperand (0 );
2286
2298
}
2299
+
2287
2300
EVT EltTy = Vector.getValueType ().getVectorElementType ();
2288
2301
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2289
2302
STI.getPTXVersion () >= 88 ;
2290
2303
2291
2304
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2292
2305
// number of inputs they take.
2293
2306
SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2294
- bool IsReassociatable;
2307
+
2308
+ // Whether we can lower to scalar operations in an arbitrary order.
2309
+ bool IsAssociative;
2295
2310
2296
2311
switch (Op->getOpcode ()) {
2297
2312
case ISD::VECREDUCE_FADD:
2298
2313
case ISD::VECREDUCE_SEQ_FADD:
2299
2314
ScalarOps = {{ISD::FADD, 2 }};
2300
- IsReassociatable = false ;
2315
+ IsAssociative = Op->getOpcode () == ISD::VECREDUCE_FADD;
2301
2316
break ;
2302
2317
case ISD::VECREDUCE_FMUL:
2303
2318
case ISD::VECREDUCE_SEQ_FMUL:
2304
2319
ScalarOps = {{ISD::FMUL, 2 }};
2305
- IsReassociatable = false ;
2320
+ IsAssociative = Op->getOpcode () == ISD::VECREDUCE_FMUL;
2306
2321
break ;
2307
2322
case ISD::VECREDUCE_FMAX:
2308
2323
if (CanUseMinMax3)
2309
2324
ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2310
2325
ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2311
- IsReassociatable = false ;
2326
+ // Definition of maxNum in IEEE 754 2008 is non-associative, but only
2327
+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2328
+ // sNaNs.
2329
+ IsAssociative = true ;
2312
2330
break ;
2313
2331
case ISD::VECREDUCE_FMIN:
2314
2332
if (CanUseMinMax3)
2315
2333
ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2316
2334
ScalarOps.push_back ({ISD::FMINNUM, 2 });
2317
- IsReassociatable = false ;
2335
+ // Definition of minNum in IEEE 754 2008 is non-associative, but only
2336
+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2337
+ // sNaNs.
2338
+ IsAssociative = true ;
2318
2339
break ;
2319
2340
case ISD::VECREDUCE_FMAXIMUM:
2320
2341
if (CanUseMinMax3)
2321
2342
ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2322
2343
ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2323
- IsReassociatable = false ;
2344
+ IsAssociative = true ;
2324
2345
break ;
2325
2346
case ISD::VECREDUCE_FMINIMUM:
2326
2347
if (CanUseMinMax3)
2327
2348
ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2328
2349
ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2329
- IsReassociatable = false ;
2350
+ IsAssociative = true ;
2330
2351
break ;
2331
2352
case ISD::VECREDUCE_ADD:
2332
2353
ScalarOps = {{ISD::ADD, 2 }};
2333
- IsReassociatable = true ;
2354
+ IsAssociative = true ;
2334
2355
break ;
2335
2356
case ISD::VECREDUCE_MUL:
2336
2357
ScalarOps = {{ISD::MUL, 2 }};
2337
- IsReassociatable = true ;
2358
+ IsAssociative = true ;
2338
2359
break ;
2339
2360
case ISD::VECREDUCE_UMAX:
2340
2361
ScalarOps = {{ISD::UMAX, 2 }};
2341
- IsReassociatable = true ;
2362
+ IsAssociative = true ;
2342
2363
break ;
2343
2364
case ISD::VECREDUCE_UMIN:
2344
2365
ScalarOps = {{ISD::UMIN, 2 }};
2345
- IsReassociatable = true ;
2366
+ IsAssociative = true ;
2346
2367
break ;
2347
2368
case ISD::VECREDUCE_SMAX:
2348
2369
ScalarOps = {{ISD::SMAX, 2 }};
2349
- IsReassociatable = true ;
2370
+ IsAssociative = true ;
2350
2371
break ;
2351
2372
case ISD::VECREDUCE_SMIN:
2352
2373
ScalarOps = {{ISD::SMIN, 2 }};
2353
- IsReassociatable = true ;
2374
+ IsAssociative = true ;
2354
2375
break ;
2355
2376
case ISD::VECREDUCE_AND:
2356
2377
ScalarOps = {{ISD::AND, 2 }};
2357
- IsReassociatable = true ;
2378
+ IsAssociative = true ;
2358
2379
break ;
2359
2380
case ISD::VECREDUCE_OR:
2360
2381
ScalarOps = {{ISD::OR, 2 }};
2361
- IsReassociatable = true ;
2382
+ IsAssociative = true ;
2362
2383
break ;
2363
2384
case ISD::VECREDUCE_XOR:
2364
2385
ScalarOps = {{ISD::XOR, 2 }};
2365
- IsReassociatable = true ;
2386
+ IsAssociative = true ;
2366
2387
break ;
2367
2388
default :
2368
2389
llvm_unreachable ("unhandled vecreduce operation");
@@ -2379,18 +2400,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2379
2400
}
2380
2401
2381
2402
// Lower to tree reduction.
2382
- if (IsReassociatable || Flags.hasAllowReassociation ()) {
2383
- // we don't expect an accumulator for reassociatable vector reduction ops
2403
+ if (IsAssociative || allowUnsafeFPMath (DAG.getMachineFunction ())) {
2404
+ // we don't expect an accumulator for reassociative vector reduction ops
2384
2405
assert (!Accumulator && " unexpected accumulator" );
2385
2406
return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2386
2407
}
2387
2408
2388
2409
// Lower to sequential reduction.
2389
2410
for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2411
+ // Try to reduce the remaining sequence as much as possible using the
2412
+ // current operator.
2390
2413
assert (OpIdx < ScalarOps.size () && "no smaller operators for reduction");
2391
2414
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2392
2415
2393
2416
if (!Accumulator) {
2417
+ // Try to initialize the accumulator using the current operator.
2394
2418
if (I + DefaultGroupSize <= NumElts) {
2395
2419
Accumulator = DAG.getNode (
2396
2420
DefaultScalarOp, DL, EltTy,
0 commit comments