Commit 24925dc
[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in PTX 8.8.

This method of tree reduction has a few benefits over the default in
DAGTypeLegalizer:

- The default shuffle reduction progressively halves and partially reduces
  the vector until it reaches a single element. This produces a sequence of
  operations that combine disparate elements of the vector. For example,
  `vecreduce_fadd <4 x f32><a b c d>` will give `(a + c) + (b + d)`, whereas
  the tree reduction produces `(a + b) + (c + d)` by grouping nearby elements
  together first. Both use the same number of registers, but the shuffle
  reduction has longer live ranges. The same example is graphed below: note
  that we hold onto 3 registers for 2 cycles in the shuffle reduction, but
  for only 1 cycle in the tree reduction (see also the IR sketch below).

  (shuffle reduction)
  PTX:
    %r1 = add.f32 a, c
    %r2 = add.f32 b, d
    %r3 = add.f32 %r1, %r2

  Pipeline:
    cycles --->   | 1            | 2            | 3            | 4                  | 5                  | 6                      |
                  | a = load.f32 | b = load.f32 | c = load.f32 | d = load.f32       |                    |                        |
                  |              |              |              | %r1 = add.f32 a, c | %r2 = add.f32 b, d | %r3 = add.f32 %r1, %r2 |
    live regs --> | a [1R]       | a b [2R]     | a b c [3R]   | b d %r1 [3R]       | %r1 %r2 [2R]       | %r3 [1R]               |

  (tree reduction)
  PTX:
    %r1 = add.f32 a, b
    %r2 = add.f32 c, d
    %r3 = add.f32 %r1, %r2

  Pipeline:
    cycles --->   | 1            | 2            | 3                  | 4            | 5                  | 6                      |
                  | a = load.f32 | b = load.f32 | c = load.f32       | d = load.f32 |                    |                        |
                  |              |              | %r1 = add.f32 a, b |              | %r2 = add.f32 c, d | %r3 = add.f32 %r1, %r2 |
    live regs --> | a [1R]       | a b [2R]     | c %r1 [2R]         | c %r1 d [3R] | %r1 %r2 [2R]       | %r3 [1R]               |

- The shuffle reduction cannot easily support fmax3/fmin3 because it
  progressively halves the input vector.

- Faster compile time: the lowering happens in one pass over the intrinsic,
  rather than O(N) passes if iteratively splitting the vector operands.
1 parent 82acd8c commit 24925dc
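For concreteness, here is a minimal LLVM IR sketch of a reduction this
lowering targets (illustrative only: the function name, vector width, and
start value are hypothetical, not taken from this commit's tests):

  define float @reduce_fadd(<4 x float> %v) {
    ; 'reassoc' permits the tree lowering; without it, SelectionDAG emits
    ; VECREDUCE_SEQ_FADD and the strictly ordered sequential form is used.
    %r = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.0, <4 x float> %v)
    ret float %r
  }
  declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)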

File tree

5 files changed: +1122 -706 lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 295 additions & 0 deletions
@@ -852,6 +852,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
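As a sketch of what the registration above covers: floating-point element
types (f16/bf16/f32/f64) route the FP reductions to the custom hook, and
integer element types route the arithmetic, bitwise, and min/max reductions
there. Custom lowering may still return SDValue() and defer to the default
shuffle expansion when that form is preferable. Hypothetical declarations
(not from this commit's tests):

  ; Both of these now reach NVPTXTargetLowering::LowerVECREDUCE:
  declare float @llvm.vector.reduce.fmax.v8f32(<8 x float>)
  declare i32 @llvm.vector.reduce.umax.v4i32(<4 x i32>)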
@@ -1109,6 +1130,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2108,6 +2133,258 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// amount of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // We have leftover elements. Why?
+
+      if (ReducedLevel.empty()) {
+        // ...because this level is now so small that the current operator is
+        // too big for it. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // ...because the operator's required number of inputs doesn't evenly
+      // divide this level. We push the remainder to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // Definition of maxNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // Definition of minNum in IEEE 754 2008 is non-associative due to handling
+    // of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
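To make the level-by-level grouping concrete, here is a hand-traced sketch of
BuildTreeReduction for an 8-element VECREDUCE_FMAXIMUM on sm_100, where
ScalarOps is {{NVPTXISD::FMAXIMUM3, 3}, {ISD::FMAXIMUM, 2}} (the names e0..e7
and t0..t3 are illustrative, not from the patch):

  ; Level 0: e0 e1 e2 e3 e4 e5 e6 e7    -- 8 elements, groups of 3
  ;   t0 = max3(e0, e1, e2)
  ;   t1 = max3(e3, e4, e5)             -- e6 and e7 carry to the next level
  ; Level 1: t0 t1 e6 e7                -- 4 elements, groups of 3
  ;   t2 = max3(t0, t1, e6)             -- e7 carries to the next level
  ; Level 2: t2 e7                      -- too small for max3, so OpIdx
  ;   t3 = max(t2, e7)                  --   advances to the 2-input max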
@@ -2940,6 +3217,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
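The non-reassociative path above expands the SEQ_* forms in strict element
order, seeded from the accumulator operand. A hypothetical example that takes
this path (no 'reassoc' flag, so SelectionDAG emits VECREDUCE_SEQ_FADD):

  define float @reduce_fadd_strict(float %acc, <4 x float> %v) {
    ; Lowered as (((acc + v0) + v1) + v2) + v3, preserving evaluation order.
    %r = call float @llvm.vector.reduce.fadd.v4f32(float %acc, <4 x float> %v)
    ret float %r
  }
  declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)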

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -300,6 +305,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
