@@ -852,6 +852,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);

+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
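+  // For example, with reassociation allowed, a VECREDUCE_FADD over v4f32 may
+  // become a two-level tree of scalar FADDs, while VECREDUCE_SEQ_FADD keeps a
+  // strict in-order chain.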
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1109,6 +1130,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2108,6 +2133,259 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }

+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// number of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
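+///
+/// For example, with a single 2-input operator, 8 elements reduce as:
+///   level 0: (e0 op e1) (e2 op e3) (e4 op e5) (e6 op e7)
+///   level 1: (t0 op t1) (t2 op t3)
+///   level 2: (u0 op u1) -> result
+/// With Ops = {{fmax3, 3}, {fmax, 2}} and 4 elements, level 0 combines
+/// (e0, e1, e2) and carries e3 over, and level 1 falls back to 2-input fmax.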
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2149
+ // Build the reduction tree at each level, starting with all the elements.
2150
+ SmallVector<SDValue> Level = Elements;
2151
+
2152
+ unsigned OpIdx = 0 ;
2153
+ while (Level.size () > 1 ) {
2154
+ // Try to reduce this level using the current operator.
2155
+ const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2156
+
2157
+ // Build the next level by partially reducing all elements.
2158
+ SmallVector<SDValue> ReducedLevel;
2159
+ unsigned I = 0 , E = Level.size ();
2160
+ for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
2161
+ // Reduce elements in groups of [DefaultGroupSize], as much as possible.
2162
+ ReducedLevel.push_back (DAG.getNode (
2163
+ DefaultScalarOp, DL, EltTy,
2164
+ ArrayRef<SDValue>(Level).slice (I, DefaultGroupSize), Flags));
2165
+ }
2166
+
2167
+ if (I < E) {
2168
+ // Handle leftover elements.
2169
+
2170
+ if (ReducedLevel.empty ()) {
2171
+ // We didn't reduce anything at this level. We need to pick a smaller
2172
+ // operator.
2173
+ ++OpIdx;
2174
+ assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
2175
+ continue ;
2176
+ }
2177
+
2178
+ // We reduced some things but there's still more left, meaning the
2179
+ // operator's number of inputs doesn't evenly divide this level size. Move
2180
+ // these elements to the next level.
2181
+ for (; I < E; ++I)
2182
+ ReducedLevel.push_back (Level[I]);
2183
+ }
2184
+
2185
+ // Process the next level.
2186
+ Level = ReducedLevel;
2187
+ }
2188
+
2189
+ return *Level.begin ();
2190
+ }
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
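+///
+/// For example, a hypothetical input such as
+///   %r = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> %v)
+/// arrives here as VECREDUCE_FMAX. On sm_100+ with PTX 8.8 it can lower to
+/// an FMAXNUM3 of the first three elements followed by a 2-input FMAXNUM.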
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // Special case: the accumulator is provided as the first argument.
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // Default case: the vector is the only input.
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
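+  // Three-input min3/max3 is only usable for f32 on sm_100+ with PTX ISA 8.8
+  // or later.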
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // The definition of maxNum in IEEE 754-2008 is non-associative due to its
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // The definition of minNum in IEEE 754-2008 is non-associative due to its
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in a shuffle reduction.
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in a shuffle reduction.
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types.
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If a shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
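+  // For example, VECREDUCE_SEQ_FADD of (s, <4 x float> %v) must stay strictly
+  // in order: ((((s + e0) + e1) + e2) + e3).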
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator =
+            DAG.getNode(DefaultScalarOp, DL, EltTy,
+                        ArrayRef(Elements).slice(I, DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2940,6 +3218,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD: