Commit f879526

[NVPTX] add combiner rule for final packed op in reduction

1 parent 5a6a4b6

File tree: 2 files changed (+210, -244 lines)

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 108 additions & 6 deletions
@@ -843,6 +843,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
+  // shuffles when one of their lanes is a no-op.
+  if (STI.allowFP16Math() || STI.hasBF16Math())
+    // Already added above: FADD, ADD, AND.
+    setTargetDAGCombine({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
+                         ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
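Why these extra opcodes get a combiner rule: with a 2-wide packed type, the final step of a tree reduction is a packed op that produces only one useful lane. A standalone scalar stand-in for that shape (plain C++, all names ours, not LLVM code; a float array models the f16x2 lanes):

#include <cstdio>

int main() {
  float v[4] = {1, 2, 3, 4}; // v4f16<a b c d> in the doc comment below

  // Step 1: one fully-used packed op: <a b> + <c d> == <a+c b+d>
  float p[2] = {v[0] + v[2], v[1] + v[3]};

  // Step 2: the default expansion emits another *packed* op that pairs the
  // two live lanes of p with an undef lane; only one result lane is ever
  // read. The new rule rewrites this final step as a single scalar op:
  float r = p[1] + p[0];

  std::printf("%f\n", r); // 10 == a+b+c+d
}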
@@ -5059,20 +5066,102 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
   return PerformStoreCombineHelper(N, 2, 0);
 }
 
+/// For vector reductions, the final result needs to be a scalar. The default
+/// expansion will use packed ops (ex. fadd.f16x2) even for the final operation.
+/// This results in a packed operation where one of the lanes is undef.
+///
+/// ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
+///
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
+/// v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
+/// vR: f16 = extractelt v3, 0
+///
+/// We wish to replace vR, v3, and v2 with:
+/// vR: f16 = fadd reassoc (extractelt v1, 1), (extractelt v1, 0)
+///
+/// ...so that we get:
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// s1: f16 = extractelt v1, 1
+/// s2: f16 = extractelt v1, 0
+/// vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
+///
+/// So for this example, this rule will replace v3 and v2, returning a vector
+/// with the result in lane 0 and an undef in lane 1, which we expect will be
+/// folded into the extractelt in vR.
+static SDValue PerformPackedOpCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  // Convert:
+  // (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
+  // ...or...
+  // (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
+  // ...where i is a valid index and u is poison.
+  const EVT VectorVT = N->getValueType(0);
+  if (!Isv2x16VT(VectorVT))
+    return SDValue();
+
+  SDLoc DL(N);
+
+  SDValue ShufOp = N->getOperand(0);
+  SDValue VectOp = N->getOperand(1);
+  bool Swapped = false;
+
+  // Canonicalize the shuffle to operand 0.
+  if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
+    std::swap(ShufOp, VectOp);
+    Swapped = true;
+  }
+
+  if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
+    return SDValue();
+
+  auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
+  int LiveLane; // exclusively live lane
+  for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
+    // Check whether the current lane is live and the other lane is dead.
+    if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
+        ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
+      break;
+  }
+  if (LiveLane == 2)
+    return SDValue();
+
+  int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
+  const EVT ScalarVT = VectorVT.getScalarType();
+  SDValue Lanes[2] = {};
+  for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
+    if (LaneID == (unsigned)LiveLane) {
+      SDValue Operands[2] = {
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
+                                      ElementIdx),
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
+      // Preserve the original order of the operands.
+      if (Swapped)
+        std::swap(Operands[0], Operands[1]);
+      LaneVal = DCI.DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
+    } else {
+      LaneVal = DCI.DAG.getUNDEF(ScalarVT);
+    }
+  }
+  return DCI.DAG.getBuildVector(VectorVT, DL, Lanes);
+}
+
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
 ///
 static SDValue PerformADDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
-  if (OptLevel == CodeGenOptLevel::None)
-    return SDValue();
-
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
   // Skip non-integer, non-scalar case
   EVT VT = N0.getValueType();
-  if (VT.isVector() || VT != MVT::i32)
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+  if (VT != MVT::i32)
+    return SDValue();
+
+  if (OptLevel == CodeGenOptLevel::None)
     return SDValue();
 
   // First try with the default operand order.
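The live-lane scan in the loop above accepts exactly the <i,u> and <u,i> masks. A minimal standalone model of that test (plain C++, not the patch itself; the helper name is ours, and PoisonMaskElem is assumed to be LLVM's -1 sentinel for a poison mask lane):

#include <cassert>
#include <optional>

constexpr int PoisonMaskElem = -1; // assumed sentinel for a poison mask lane

// Returns the exclusively live lane of a 2-lane shuffle mask: exactly one
// lane must name a real source element while the other lane is poison.
std::optional<int> exclusivelyLiveLane(const int (&Mask)[2]) {
  for (int Lane = 0; Lane < 2; ++Lane)
    if (Mask[Lane] != PoisonMaskElem && Mask[!Lane] == PoisonMaskElem)
      return Lane;
  return std::nullopt; // both lanes live, or both poison: no combine
}

int main() {
  int ShufHiU[2] = {1, PoisonMaskElem}; // vector_shuffle<1,u>
  int ShufULo[2] = {PoisonMaskElem, 0}; // vector_shuffle<u,0>
  int BothLive[2] = {0, 1};             // ordinary shuffle, no dead lane
  assert(exclusivelyLiveLane(ShufHiU) == 0);
  assert(exclusivelyLiveLane(ShufULo) == 1);
  assert(!exclusivelyLiveLane(BothLive).has_value());
}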
@@ -5092,7 +5181,10 @@ static SDValue PerformFADDCombine(SDNode *N,
   SDValue N1 = N->getOperand(1);
 
   EVT VT = N0.getValueType();
-  if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+
+  if (!(VT == MVT::f32 || VT == MVT::f64))
     return SDValue();
 
   // First try with the default operand order.
@@ -5195,7 +5287,7 @@ static SDValue PerformANDCombine(SDNode *N,
     DCI.CombineTo(N, Val, AddTo);
   }
 
-  return SDValue();
+  return PerformPackedOpCombine(N, DCI);
 }
 
 static SDValue PerformREMCombine(SDNode *N,
@@ -5676,6 +5768,16 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformADDCombine(N, DCI, OptLevel);
   case ISD::FADD:
     return PerformFADDCombine(N, DCI, OptLevel);
+  case ISD::FMUL:
+  case ISD::FMINNUM:
+  case ISD::FMAXIMUM:
+  case ISD::UMIN:
+  case ISD::UMAX:
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::OR:
+  case ISD::XOR:
+    return PerformPackedOpCombine(N, DCI);
   case ISD::MUL:
     return PerformMULCombine(N, DCI, OptLevel);
   case ISD::SHL:
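As a sanity check on the rewrite the doc comment describes, the before and after forms compute the same scalar, because only the live lane of the final packed op is ever read. A standalone scalar check (plain C++, names ours, not LLVM code; exact binary values chosen so reassociation is harmless):

#include <cassert>

int main() {
  float a = 1.5f, b = 2.5f, c = 3.5f, d = 4.5f;

  // v1 = fadd <a b>, <c d>  ==> <a+c b+d>
  float v1[2] = {a + c, b + d};

  // Before: v2 = vector_shuffle<1,u> v1; v3 = fadd v2, v1; vR = extractelt v3, 0
  float v2_lane0 = v1[1];             // lane 1 of v2 is undef, never read
  float vR_before = v2_lane0 + v1[0]; // live lane 0 of v3

  // After: vR = fadd (extractelt v1, 1), (extractelt v1, 0)
  float vR_after = v1[1] + v1[0];

  assert(vR_before == vR_after);
  assert(vR_before == a + b + c + d); // 12, exact in binary floating point
}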
