Commit 8cbda00

[NVPTX] add combiner rule for final packed op in reduction

1 parent: 43be31e
2 files changed: +210 -244 lines

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (108 additions, 6 deletions)
@@ -852,6 +852,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
+  // shuffles when one of their lanes is a no-op.
+  if (STI.allowFP16Math() || STI.hasBF16Math())
+    // already added above: FADD, ADD, AND
+    setTargetDAGCombine({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
+                         ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
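Note on the "already added above" comment: the target hook is only consulted for opcodes that were passed to setTargetDAGCombine(), so without this registration the new combine would never see FMUL/FMIN/FMAX/etc. nodes. As a rough standalone model (plain C++; the enum is a hypothetical stand-in for ISD::NodeType, not LLVM's actual implementation):

#include <bitset>
#include <initializer_list>

// Hypothetical stand-in for a subset of ISD::NodeType.
enum Opcode { FADD, FMUL, FMINIMUM, FMAXIMUM, UMIN, UMAX, SMIN, SMAX, OR, XOR, NUM_OPCODES };

struct TargetLoweringModel {
  std::bitset<NUM_OPCODES> WantsCombine;

  // Analogue of setTargetDAGCombine: record each opcode in a per-target bitset.
  void setTargetDAGCombine(std::initializer_list<Opcode> Ops) {
    for (Opcode Op : Ops)
      WantsCombine.set(Op);
  }

  // The generic combiner checks this bitset before calling the target's
  // PerformDAGCombine, so unregistered opcodes never reach the hook.
  bool hasTargetDAGCombine(Opcode Op) const { return WantsCombine.test(Op); }
};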
@@ -5069,20 +5076,102 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
   return PerformStoreCombineHelper(N, 2, 0);
 }
 
+/// For vector reductions, the final result needs to be a scalar. The default
+/// expansion will use packed ops (ex. fadd.f16x2) even for the final operation.
+/// This requires a packed operation where one of the lanes is undef.
+///
+/// ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
+///
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
+/// v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
+/// vR: f16 = extractelt v3, 0
+///
+/// We wish to replace vR, v3, and v2 with:
+/// vR: f16 = fadd reassoc (extractelt v1, 1), (extractelt v1, 0)
+///
+/// ...so that we get:
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// s1: f16 = extractelt v1, 1
+/// s2: f16 = extractelt v1, 0
+/// vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
+///
+/// So for this example, this rule will replace v3 and v2, returning a vector
+/// with the result in lane 0 and an undef in lane 1, which we expect will be
+/// folded into the extractelt in vR.
+static SDValue PerformPackedOpCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  // Convert:
+  // (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
+  // ...or...
+  // (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
+  // ...where i is a valid index and u is poison.
+  const EVT VectorVT = N->getValueType(0);
+  if (!Isv2x16VT(VectorVT))
+    return SDValue();
+
+  SDLoc DL(N);
+
+  SDValue ShufOp = N->getOperand(0);
+  SDValue VectOp = N->getOperand(1);
+  bool Swapped = false;
+
+  // canonicalize shuffle to op0
+  if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
+    std::swap(ShufOp, VectOp);
+    Swapped = true;
+  }
+
+  if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
+    return SDValue();
+
+  auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
+  int LiveLane; // exclusively live lane
+  for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
+    // check if the current lane is live and the other lane is dead
+    if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
+        ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
+      break;
+  }
+  if (LiveLane == 2)
+    return SDValue();
+
+  int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
+  const EVT ScalarVT = VectorVT.getScalarType();
+  SDValue Lanes[2] = {};
+  for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
+    if (LaneID == (unsigned)LiveLane) {
+      SDValue Operands[2] = {
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
+                                      ElementIdx),
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
+      // preserve the order of operands
+      if (Swapped)
+        std::swap(Operands[0], Operands[1]);
+      LaneVal = DCI.DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
+    } else {
+      LaneVal = DCI.DAG.getUNDEF(ScalarVT);
+    }
+  }
+  return DCI.DAG.getBuildVector(VectorVT, DL, Lanes);
+}
+
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
 ///
 static SDValue PerformADDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
-  if (OptLevel == CodeGenOptLevel::None)
-    return SDValue();
-
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
   // Skip non-integer, non-scalar case
   EVT VT = N0.getValueType();
-  if (VT.isVector() || VT != MVT::i32)
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+  if (VT != MVT::i32)
+    return SDValue();
+
+  if (OptLevel == CodeGenOptLevel::None)
     return SDValue();
 
   // First try with the default operand order.
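The heart of the new combine is the mask test: the shuffle must have exactly one live lane and one poison lane. A minimal standalone sketch of that check (plain C++, no LLVM dependencies; PoisonMaskElem is mirrored here as -1, the sentinel value llvm::PoisonMaskElem uses):

#include <optional>

constexpr int PoisonMaskElem = -1; // mirrors llvm::PoisonMaskElem

// Returns the exclusively live lane (0 or 1) of a 2-element shuffle mask,
// or std::nullopt when both lanes are live or both are poison -- the cases
// where PerformPackedOpCombine bails out with an empty SDValue.
std::optional<int> exclusivelyLiveLane(const int Mask[2]) {
  for (int LiveLane = 0; LiveLane < 2; ++LiveLane)
    if (Mask[LiveLane] != PoisonMaskElem && Mask[!LiveLane] == PoisonMaskElem)
      return LiveLane;
  return std::nullopt;
}

// e.g. {1, PoisonMaskElem} -> lane 0 is live and reads element 1, while
// {0, 1} and {PoisonMaskElem, PoisonMaskElem} both return std::nullopt.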
@@ -5102,7 +5191,10 @@ static SDValue PerformFADDCombine(SDNode *N,
   SDValue N1 = N->getOperand(1);
 
   EVT VT = N0.getValueType();
-  if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+
+  if (!(VT == MVT::f32 || VT == MVT::f64))
     return SDValue();
 
   // First try with the default operand order.
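For concreteness, the worked fadd example from PerformPackedOpCombine's comment can be checked numerically. A standalone C++ sketch (float standing in for f16, names following the comment; not part of the commit):

#include <cassert>

int main() {
  const float a = 1.0f, b = 2.0f, c = 3.0f, d = 4.0f; // V = v4f16<a b c d>

  // v1 = packed fadd of the two halves of V: <a+c, b+d>.
  const float v1[2] = {a + c, b + d};

  // Old tail: v2 = shuffle<1,u> v1 = <b+d, undef>;
  // v3 = packed fadd v2, v1 = <(b+d)+(a+c), undef>; vR = extractelt v3, 0.
  const float vR_before = v1[1] + v1[0];

  // New tail after the combine: one scalar fadd over v1's two lanes,
  // with no packed op computing a lane that is immediately discarded.
  const float s1 = v1[1], s2 = v1[0];
  const float vR_after = s1 + s2;

  assert(vR_before == vR_after && vR_after == a + b + c + d);
  return 0;
}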
@@ -5205,7 +5297,7 @@ static SDValue PerformANDCombine(SDNode *N,
     DCI.CombineTo(N, Val, AddTo);
   }
 
-  return SDValue();
+  return PerformPackedOpCombine(N, DCI);
 }
 
 static SDValue PerformREMCombine(SDNode *N,
@@ -5686,6 +5778,16 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformADDCombine(N, DCI, OptLevel);
   case ISD::FADD:
     return PerformFADDCombine(N, DCI, OptLevel);
+  case ISD::FMUL:
+  case ISD::FMINIMUM:
+  case ISD::FMAXIMUM:
+  case ISD::UMIN:
+  case ISD::UMAX:
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::OR:
+  case ISD::XOR:
+    return PerformPackedOpCombine(N, DCI);
   case ISD::MUL:
     return PerformMULCombine(N, DCI, OptLevel);
   case ISD::SHL:
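The rewrite now applies uniformly to every opcode in this case list (the list mirrors the opcodes registered in the constructor). By direct analogy with the fadd example in PerformPackedOpCombine's comment, the tail of a umax reduction over v4i16 would look like this (a sketch in the same notation, not taken from the commit):

v1: v2i16 = umax v2i16<a b>, v2i16<c d>          (== <max(a,c) max(b,d)>)
v2: v2i16 = vector_shuffle<1,u> v1, undef:v2i16  (== <max(b,d) undef>)
v3: v2i16 = umax v2, v1                          (== <max(a,b,c,d) undef>)
vR: i16 = extractelt v3, 0

...which the combine turns into:

v1: v2i16 = umax v2i16<a b>, v2i16<c d>
vR: i16 = umax (extractelt v1, 1), (extractelt v1, 0)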
