
Commit 36e8acd

[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in PTX 8.8.

This method of tree reduction has a few benefits over the default in
DAGTypeLegalizer:

- The default shuffle reduction progressively halves and partially reduces
  the vector down until we reach a single element. This produces a sequence
  of operations that combine disparate elements of the vector. For example,
  `vecreduce_fadd <4 x f32><a b c d>` will give `(a + c) + (b + d)`, whereas
  the tree reduction produces `(a + b) + (c + d)` by grouping nearby elements
  together first. Both use the same number of registers, but the shuffle
  reduction has longer live ranges. The same example is graphed below. Note
  we hold onto 3 registers for 2 cycles in the shuffle reduction and 1 cycle
  in the tree reduction.

  (shuffle reduction)
  PTX:
    %r1 = add.f32 a, c
    %r2 = add.f32 b, d
    %r3 = add.f32 %r1, %r2

  Pipeline:
    cycles ---->
    |      1       |      2       |      3       |         4          |         5          |           6            |
    | a = load.f32 | b = load.f32 | c = load.f32 | d = load.f32       |                    |                        |
    |              |              |              | %r1 = add.f32 a, c | %r2 = add.f32 b, d | %r3 = add.f32 %r1, %r2 |

    live regs ---->
    | a [1R] | a b [2R] | a b c [3R] | b d %r1 [3R] | %r1 %r2 [2R] | %r3 [1R] |

  (tree reduction)
  PTX:
    %r1 = add.f32 a, b
    %r2 = add.f32 c, d
    %r3 = add.f32 %r1, %r2

  Pipeline:
    cycles ---->
    |      1       |      2       |         3          |      4       |         5          |           6            |
    | a = load.f32 | b = load.f32 | c = load.f32       | d = load.f32 |                    |                        |
    |              |              | %r1 = add.f32 a, b |              | %r2 = add.f32 c, d | %r3 = add.f32 %r1, %r2 |

    live regs ---->
    | a [1R] | a b [2R] | c %r1 [2R] | c %r1 d [3R] | %r1 %r2 [2R] | %r3 [1R] |

- The shuffle reduction cannot easily support fmax3/fmin3 because it
  progressively halves the input vector.

- Faster compile time. Happens in one pass over the intrinsic, rather than
  O(N) passes if iteratively splitting the vector operands.
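The two reduction shapes above can also be written as plain scalar C++. This is an illustrative sketch only (the function names are not from this commit): the shuffle order pairs elements half a vector apart, while the tree order pairs adjacent ones, which is what shortens the live ranges.

```cpp
#include <cstdio>

// Shuffle-style reduction: combines elements half a vector apart (a+c, b+d),
// so a and b must stay live while c and d are still being produced.
static float shuffleReduce(float a, float b, float c, float d) {
  float r1 = a + c;
  float r2 = b + d;
  return r1 + r2;
}

// Tree-style reduction: combines adjacent elements (a+b, c+d), so the first
// pair can be retired as soon as b is available.
static float treeReduce(float a, float b, float c, float d) {
  float r1 = a + b;
  float r2 = c + d;
  return r1 + r2;
}

int main() {
  // The two orders only agree in general when reassociation is allowed.
  printf("%f %f\n", shuffleReduce(1, 2, 3, 4), treeReduce(1, 2, 3, 4));
  return 0;
}
```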
1 parent 351303c commit 36e8acd

File tree

5 files changed: +1125 -706 lines changed


llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 296 additions & 0 deletions
```diff
@@ -852,6 +852,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
```
```diff
@@ -1109,6 +1130,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
```
```diff
@@ -2108,6 +2133,259 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// amount of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // Handle leftover elements.
+
+      if (ReducedLevel.empty()) {
+        // We didn't reduce anything at this level. We need to pick a smaller
+        // operator.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // We reduced some things but there's still more left, meaning the
+      // operator's number of inputs doesn't evenly divide this level size.
+      // Move these elements to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
```
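The level-building loop is easiest to follow outside of SelectionDAG. The sketch below (illustrative only; `treeReduceMax` and its group-size list are not part of the patch) replays the same algorithm on plain floats, specialized to a max reduction with the {3, 2}-input operator list used on sm_100+:

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Mirrors BuildTreeReduction on plain floats: try the widest operator first
// (group size 3, like fmax3) and fall back to the next operator only when a
// whole level is too small for it.
static float treeReduceMax(std::vector<float> Level,
                           const std::vector<unsigned> &GroupSizes) {
  unsigned OpIdx = 0;
  while (Level.size() > 1) {
    const unsigned GroupSize = GroupSizes[OpIdx];
    std::vector<float> Reduced;
    unsigned I = 0, E = Level.size();
    // Combine as many full groups of adjacent elements as possible.
    for (; I + GroupSize <= E; I += GroupSize)
      Reduced.push_back(*std::max_element(Level.begin() + I,
                                          Level.begin() + I + GroupSize));
    if (I < E) {
      if (Reduced.empty()) {
        // Nothing fit this operator; retry the level with a smaller one.
        ++OpIdx;
        assert(OpIdx < GroupSizes.size() && "no smaller operators");
        continue;
      }
      // Leftovers pass through to the next level unreduced.
      for (; I < E; ++I)
        Reduced.push_back(Level[I]);
    }
    Level = std::move(Reduced);
  }
  return Level.front();
}
```

For seven inputs, `treeReduceMax({e0,...,e6}, {3, 2})` computes `max3(max3(e0,e1,e2), max3(e3,e4,e5), e6)`: two full groups of three plus one carried-over element at level 0, then a single three-input op at level 1.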
```diff
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // Special case with the accumulator as the first argument.
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // Default case.
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{,b}f16x2 for v2{,b}f16.
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{,b}f16x2 for v2{,b}f16.
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // The definition of maxNum in IEEE 754-2008 is non-associative due to the
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // The definition of minNum in IEEE 754-2008 is non-associative due to the
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in a shuffle reduction.
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16.
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in a shuffle reduction.
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16.
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types.
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If a shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator =
+            DAG.getNode(DefaultScalarOp, DL, EltTy,
+                        ArrayRef(Elements).slice(I, DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
```
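When reassociation is not allowed (e.g. `vecreduce_seq_fadd`, or an ordered fmax without 'reassoc'), the sequential path above folds elements into the accumulator strictly left to right. A minimal sketch of that shape, assuming a 2-input scalar op (names are illustrative, not from the patch):

```cpp
#include <vector>

// Sequential (in-order) reduction: with a 2-input add, each step folds exactly
// one more element into the running accumulator, preserving the evaluation
// order required by the IR: ((((Acc + e0) + e1) + e2) + e3) ...
static float seqReduceFAdd(float Acc, const std::vector<float> &Elts) {
  for (float E : Elts)
    Acc = Acc + E;
  return Acc;
}
```

With a wider operator such as a 3-input fmax3, the same loop instead folds two new elements per step, which is why the sequential loop in the patch strides by `DefaultGroupSize - 1` once the accumulator exists.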
```diff
@@ -2940,6 +3218,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
```

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
```diff
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
 
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -300,6 +305,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
```
