@@ -852,6 +852,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
+  // shuffles when one of their lanes is a no-op.
+  if (STI.allowFP16Math() || STI.hasBF16Math())
+    // already added above: FADD, ADD, AND
+    setTargetDAGCombine({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
+                         ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -5069,20 +5076,102 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
   return PerformStoreCombineHelper(N, 2, 0);
 }
 
+/// For vector reductions, the final result needs to be a scalar. The default
+/// expansion will use packed ops (ex. fadd.f16x2) even for the final operation.
+/// This requires a packed operation where one of the lanes is undef.
+///
+/// ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
+///
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
+/// v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
+/// vR: f16 = extractelt v3, 1
+///
+/// We wish to replace vR, v3, and v2 with:
+/// vR: f16 = fadd reassoc (extractelt v1, 1) (extractelt v1, 0)
+///
+/// ...so that we get:
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// s1: f16 = extractelt v1, 1
+/// s2: f16 = extractelt v1, 0
+/// vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
+///
+/// So for this example, this rule will replace v3 and v2, returning a vector
+/// with the result in lane 0 and an undef in lane 1, which we expect will be
+/// folded into the extractelt in vR.
+static SDValue PerformPackedOpCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  // Convert:
+  // (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
+  // ...or...
+  // (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
+  // ...where i is a valid index and u is poison.
+  const EVT VectorVT = N->getValueType(0);
+  if (!Isv2x16VT(VectorVT))
+    return SDValue();
+
+  SDLoc DL(N);
+
+  SDValue ShufOp = N->getOperand(0);
+  SDValue VectOp = N->getOperand(1);
+  bool Swapped = false;
+
+  // canonicalize shuffle to op0
+  if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
+    std::swap(ShufOp, VectOp);
+    Swapped = true;
+  }
+
+  if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
+    return SDValue();
+
+  auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
+  int LiveLane; // exclusively live lane
+  for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
+    // check if the current lane is live and the other lane is dead
+    if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
+        ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
+      break;
+  }
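+  // If no lane is exclusively live, there is nothing to fold.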
+  if (LiveLane == 2)
+    return SDValue();
+
+  int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
+  const EVT ScalarVT = VectorVT.getScalarType();
+  SDValue Lanes[2] = {};
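+  // Rebuild the result: the scalar op in the live lane, undef in the other.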
+  for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
+    if (LaneID == (unsigned)LiveLane) {
+      SDValue Operands[2] = {
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
+                                      ElementIdx),
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
+      // preserve the order of operands
+      if (Swapped)
+        std::swap(Operands[0], Operands[1]);
+      LaneVal = DCI.DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
+    } else {
+      LaneVal = DCI.DAG.getUNDEF(ScalarVT);
+    }
+  }
+  return DCI.DAG.getBuildVector(VectorVT, DL, Lanes);
+}
+
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
 ///
 static SDValue PerformADDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
-  if (OptLevel == CodeGenOptLevel::None)
-    return SDValue();
-
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
   // Skip non-integer, non-scalar case
   EVT VT = N0.getValueType();
-  if (VT.isVector() || VT != MVT::i32)
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+  if (VT != MVT::i32)
+    return SDValue();
+
+  if (OptLevel == CodeGenOptLevel::None)
     return SDValue();
 
   // First try with the default operand order.
@@ -5102,7 +5191,10 @@ static SDValue PerformFADDCombine(SDNode *N,
   SDValue N1 = N->getOperand(1);
 
   EVT VT = N0.getValueType();
-  if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+
+  if (!(VT == MVT::f32 || VT == MVT::f64))
     return SDValue();
 
   // First try with the default operand order.
@@ -5205,7 +5297,7 @@ static SDValue PerformANDCombine(SDNode *N,
     DCI.CombineTo(N, Val, AddTo);
   }
 
-  return SDValue();
+  return PerformPackedOpCombine(N, DCI);
 }
 
 static SDValue PerformREMCombine(SDNode *N,
@@ -5686,6 +5778,16 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformADDCombine(N, DCI, OptLevel);
   case ISD::FADD:
     return PerformFADDCombine(N, DCI, OptLevel);
+  case ISD::FMUL:
+  case ISD::FMINIMUM:
+  case ISD::FMAXIMUM:
+  case ISD::UMIN:
+  case ISD::UMAX:
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::OR:
+  case ISD::XOR:
+    return PerformPackedOpCombine(N, DCI);
  case ISD::MUL:
    return PerformMULCombine(N, DCI, OptLevel);
  case ISD::SHL: