@@ -852,6 +852,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);

+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
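+  // For example (an illustrative note, not part of the original comment): a
+  // @llvm.vector.reduce.fadd over <4 x float> reaches LowerVECREDUCE below as
+  // ISD::VECREDUCE_FADD (or ISD::VECREDUCE_SEQ_FADD in the ordered case) and
+  // is expanded there into scalar FADD nodes.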
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1109,6 +1130,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2108,6 +2133,258 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }

+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// number of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
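+///
+/// For example (an illustrative trace, not part of the original comment):
+/// reducing eight elements e0..e7 with Ops = {{FMAXNUM3, 3}, {FMAXNUM, 2}}
+/// builds:
+///   level 1: a = max3(e0,e1,e2), b = max3(e3,e4,e5); e6, e7 carried over
+///   level 2: c = max3(a,b,e6); e7 carried over
+///   level 3: max(c,e7)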
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // We have leftover elements. Why?
+
+      if (ReducedLevel.empty()) {
+        // ...because this level is now so small that the current operator is
+        // too big for it. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // ...because the operator's required number of inputs doesn't evenly
+      // divide this level. We push the remainder to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // Special case: the accumulator is the first operand.
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // Default case: the vector is the only operand.
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
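+  // For example (an illustrative note, not part of the original comment): for
+  // a v4f16 FADD reduction, a shuffle reduction can keep both halves packed in
+  // 32-bit registers and combine them with add.f16x2, whereas a scalar tree
+  // would unpack all four f16 values first.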
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // The IEEE 754-2008 definition of maxNum is non-associative due to its
+    // handling of sNaN inputs. Allow reassociation with fast-math or the
+    // 'reassoc' flag.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // The IEEE 754-2008 definition of minNum is non-associative due to its
+    // handling of sNaN inputs. Allow reassociation with fast-math or the
+    // 'reassoc' flag.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If shuffle reduction is preferred, leave it to SelectionDAG.
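+  // (Editor's note on the surrounding machinery: returning an empty SDValue
+  // from a Custom lowering hook makes SelectionDAG fall back to its default
+  // expansion, which for VECREDUCE nodes is the halving-shuffle reduction.)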
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
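+  // For example (an illustrative trace, not part of the original comment): a
+  // strict VECREDUCE_FMAX over v8f32 with ScalarOps = {{FMAXNUM3, 3},
+  // {FMAXNUM, 2}} first forms acc = max3(e0,e1,e2), then folds two elements
+  // per step: acc = max3(acc,e3,e4), acc = max3(acc,e5,e6), and finishes with
+  // acc = max(acc,e7).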
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
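
As a quick way to exercise the new path (an editor's sketch under stated assumptions, not part of the patch): the snippet below uses IRBuilder to emit a @llvm.vector.reduce.fmax call over <8 x float>. Compiled for an NVPTX target with sm_100 and PTX 8.8, it should reach LowerVECREDUCE above as ISD::VECREDUCE_FMAX and can select the max3 form. The file and function names are hypothetical.

// reduce_example.cpp (illustrative only)
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("reduce_example", Ctx);
  // float reduce_fmax(<8 x float> %v)
  auto *VecTy = FixedVectorType::get(Type::getFloatTy(Ctx), 8);
  auto *FnTy = FunctionType::get(Type::getFloatTy(Ctx), {VecTy}, false);
  Function *F =
      Function::Create(FnTy, Function::ExternalLinkage, "reduce_fmax", M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));
  // @llvm.vector.reduce.fmax.v8f32 becomes ISD::VECREDUCE_FMAX in the DAG.
  Value *R =
      B.CreateUnaryIntrinsic(Intrinsic::vector_reduce_fmax, F->getArg(0));
  B.CreateRet(R);
  M.print(outs(), nullptr); // pipe into: llc -mtriple=nvptx64 -mcpu=sm_100
  return verifyModule(M, &errs()) ? 1 : 0;
}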
@@ -2940,6 +3217,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD: