@@ -843,6 +843,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
  if (STI.allowFP16Math() || STI.hasBF16Math())
    setTargetDAGCombine(ISD::SETCC);

+  // Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
+  // shuffles when one of their lanes is a no-op.
+  if (STI.allowFP16Math() || STI.hasBF16Math())
+    // already added above: FADD, ADD, AND
+    setTargetDAGCombine({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
+                         ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});
+
  // Promote fp16 arithmetic if fp16 hardware isn't available or the
  // user passed --nvptx-no-fp16-math. The flag is useful because,
  // although sm_53+ GPUs have some sort of FP16 support in
@@ -5059,20 +5066,102 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
  return PerformStoreCombineHelper(N, 2, 0);
}

+/// For vector reductions, the final result needs to be a scalar. The default
+/// expansion will use packed ops (ex. fadd.f16x2) even for the final operation.
+/// This requires a packed operation where one of the lanes is undef.
+///
+/// ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
+///
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
+/// v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
+/// vR: f16 = extractelt v3, 0
+///
+/// We wish to replace vR, v3, and v2 with:
+/// vR: f16 = fadd reassoc (extractelt v1, 1), (extractelt v1, 0)
+///
+/// ...so that we get:
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// s1: f16 = extractelt v1, 1
+/// s2: f16 = extractelt v1, 0
+/// vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
+///
+/// So for this example, this rule will replace v3 and v2, returning a vector
+/// with the result in lane 0 and an undef in lane 1, which we expect will be
+/// folded into the extractelt in vR.
+static SDValue PerformPackedOpCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  // Convert:
+  // (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
+  // ...or...
+  // (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
+  // ...where i is a valid index and u is poison.
+  const EVT VectorVT = N->getValueType(0);
+  if (!Isv2x16VT(VectorVT))
+    return SDValue();
+
+  SDLoc DL(N);
+
+  SDValue ShufOp = N->getOperand(0);
+  SDValue VectOp = N->getOperand(1);
+  bool Swapped = false;
+
+  // canonicalize shuffle to op0
+  if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
+    std::swap(ShufOp, VectOp);
+    Swapped = true;
+  }
+
+  if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
+    return SDValue();
+
+  auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
+  int LiveLane; // exclusively live lane
+  for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
+    // check if the current lane is live and the other lane is dead
+    if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
+        ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
+      break;
+  }
+  if (LiveLane == 2)
+    return SDValue();
+
+  int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
+  const EVT ScalarVT = VectorVT.getScalarType();
+  SDValue Lanes[2] = {};
+  for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
+    if (LaneID == (unsigned)LiveLane) {
+      SDValue Operands[2] = {
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
+                                      ElementIdx),
+          DCI.DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
+      // preserve the order of operands
+      if (Swapped)
+        std::swap(Operands[0], Operands[1]);
+      LaneVal = DCI.DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
+    } else {
+      LaneVal = DCI.DAG.getUNDEF(ScalarVT);
+    }
+  }
+  return DCI.DAG.getBuildVector(VectorVT, DL, Lanes);
+}
+
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 CodeGenOptLevel OptLevel) {
-  if (OptLevel == CodeGenOptLevel::None)
-    return SDValue();
-
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);

  // Skip non-integer, non-scalar case
  EVT VT = N0.getValueType();
-  if (VT.isVector() || VT != MVT::i32)
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+  if (VT != MVT::i32)
+    return SDValue();
+
+  if (OptLevel == CodeGenOptLevel::None)
    return SDValue();

  // First try with the default operand order.
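
The bail-out logic above hinges on the two-lane shuffle mask having exactly one live lane. Below is a standalone sketch of that test, separate from the patch; the helper name getExclusivelyLiveLane and the kPoisonMaskElem constant are illustrative, though LLVM does define PoisonMaskElem as -1.

#include <array>
#include <cassert>

// LLVM encodes poison shuffle-mask elements as -1 (PoisonMaskElem);
// that value is assumed here so the sketch stands alone.
constexpr int kPoisonMaskElem = -1;

// Return the index (0 or 1) of the single live lane of a two-lane shuffle
// mask, or -1 when zero or two lanes are live. In the latter cases the
// combine above bails out with SDValue().
static int getExclusivelyLiveLane(const std::array<int, 2> &Mask) {
  for (int Lane = 0; Lane < 2; ++Lane)
    if (Mask[Lane] != kPoisonMaskElem && Mask[!Lane] == kPoisonMaskElem)
      return Lane;
  return -1;
}

int main() {
  assert(getExclusivelyLiveLane({{1, kPoisonMaskElem}}) == 0); // <1,u>
  assert(getExclusivelyLiveLane({{kPoisonMaskElem, 0}}) == 1); // <u,0>
  assert(getExclusivelyLiveLane({{0, 1}}) == -1);              // both live
  assert(getExclusivelyLiveLane({{kPoisonMaskElem, kPoisonMaskElem}}) == -1);
  return 0;
}

In PerformPackedOpCombine the same check is written as a loop over LiveLane that breaks on the first exclusively live lane and returns SDValue() when none is found.
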
@@ -5092,7 +5181,10 @@ static SDValue PerformFADDCombine(SDNode *N,
  SDValue N1 = N->getOperand(1);

  EVT VT = N0.getValueType();
-  if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+  if (VT.isVector())
+    return PerformPackedOpCombine(N, DCI);
+
+  if (!(VT == MVT::f32 || VT == MVT::f64))
    return SDValue();

  // First try with the default operand order.
@@ -5195,7 +5287,7 @@ static SDValue PerformANDCombine(SDNode *N,
    DCI.CombineTo(N, Val, AddTo);
  }

-  return SDValue();
+  return PerformPackedOpCombine(N, DCI);
}

static SDValue PerformREMCombine(SDNode *N,
@@ -5676,6 +5768,16 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
    return PerformADDCombine(N, DCI, OptLevel);
  case ISD::FADD:
    return PerformFADDCombine(N, DCI, OptLevel);
+  case ISD::FMUL:
5773
+ case ISD::FMAXIMUM:
5774
+ case ISD::UMIN:
5775
+ case ISD::UMAX:
5776
+ case ISD::SMIN:
5777
+ case ISD::SMAX:
5778
+ case ISD::OR:
5779
+ case ISD::XOR:
5780
+ return PerformPackedOpCombine (N, DCI);
5679
5781
case ISD::MUL:
5680
5782
return PerformMULCombine (N, DCI, OptLevel);
5681
5783
case ISD::SHL:
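
The doc comment on PerformPackedOpCombine relies on the fadd nodes carrying the reassoc flag: the default vecreduce_fadd expansion computes (a+c)+(b+d) rather than the sequential ((a+b)+c)+d, and the combine only changes how the final step is expressed, not that ordering. A tiny host-C++ illustration of the reordering, with values chosen so both orders agree exactly, is sketched below (not part of the patch).

#include <cassert>

int main() {
  // Model the v4f16 input <a b c d> with floats for readability.
  const float A = 1.0f, B = 2.0f, C = 3.0f, D = 4.0f;

  // Step produced by the default expansion: v1 = <a+c, b+d>.
  const float V1[2] = {A + C, B + D};

  // Final step after PerformPackedOpCombine: a scalar add of v1's lanes
  // instead of a packed fadd whose second lane is undef.
  const float Result = V1[1] + V1[0];

  // The reassoc flag on the fadd nodes is what licenses this ordering; for
  // these small integer-valued floats both orders are exactly equal.
  assert(Result == ((A + B) + C) + D);
  return 0;
}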