[NVPTX] add combiner rule for final packed op in reduction #143943
Conversation
@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

Changes

For vector reductions, the final result needs to be a scalar. The default expansion will use packed ops (ex. fadd.f16x2) even for the final operation. This requires a packed operation where one of the lanes is undef.

ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>

v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d>   (== <a+c b+d>)
v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16   (== <b+d undef>)
v3: v2f16 = fadd reassoc v2, v1                   (== <b+d+a+c undef>)
vR: f16   = extractelt v3, 1

We wish to replace vR, v3, and v2 with:

vR: f16 = fadd reassoc (extractelt v1, 1) (extractelt v1, 0)

...so that we get:

v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d>   (== <a+c b+d>)
s1: f16   = extractelt v1, 1
s2: f16   = extractelt v1, 0
vR: f16   = fadd reassoc s1, s2                   (== a+c+b+d)

So for this example, this rule will replace v3 and v2, returning a vector with the result in lane 0 and an undef in lane 1, which we expect will be folded into the extractelt in vR.

Patch is 27.57 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143943.diff

2 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d6a134d9abafd..7e36e5b526932 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -843,6 +843,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
if (STI.allowFP16Math() || STI.hasBF16Math())
setTargetDAGCombine(ISD::SETCC);
+ // Combine reduction operations on packed types (e.g. fadd.f16x2) with vector
+ // shuffles when one of their lanes is a no-op.
+ if (STI.allowFP16Math() || STI.hasBF16Math())
+ // already added above: FADD, ADD, AND
+ setTargetDAGCombine({ISD::FMUL, ISD::FMINIMUM, ISD::FMAXIMUM, ISD::UMIN,
+ ISD::UMAX, ISD::SMIN, ISD::SMAX, ISD::OR, ISD::XOR});
+
// Promote fp16 arithmetic if fp16 hardware isn't available or the
// user passed --nvptx-no-fp16-math. The flag is useful because,
// although sm_53+ GPUs have some sort of FP16 support in
@@ -5059,20 +5066,102 @@ static SDValue PerformStoreRetvalCombine(SDNode *N) {
return PerformStoreCombineHelper(N, 2, 0);
}
+/// For vector reductions, the final result needs to be a scalar. The default
+/// expansion will use packed ops (ex. fadd.f16x2) even for the final operation.
+/// This requires a packed operation where one of the lanes is undef.
+///
+/// ex: lowering of vecreduce_fadd(V) where V = v4f16<a b c d>
+///
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// v2: v2f16 = vector_shuffle<1,u> v1, undef:v2f16 (== <b+d undef>)
+/// v3: v2f16 = fadd reassoc v2, v1 (== <b+d+a+c undef>)
+/// vR: f16 = extractelt v3, 1
+///
+/// We wish to replace vR, v3, and v2 with:
+/// vR: f16 = fadd reassoc (extractelt v1, 1) (extractelt v1, 0)
+///
+/// ...so that we get:
+/// v1: v2f16 = fadd reassoc v2f16<a b>, v2f16<c d> (== <a+c b+d>)
+/// s1: f16 = extractelt v1, 1
+/// s2: f16 = extractelt v1, 0
+/// vR: f16 = fadd reassoc s1, s2 (== a+c+b+d)
+///
+/// So for this example, this rule will replace v3 and v2, returning a vector
+/// with the result in lane 0 and an undef in lane 1, which we expect will be
+/// folded into the extractelt in vR.
+static SDValue PerformPackedOpCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ // Convert:
+ // (fop.x2 (vector_shuffle<i,u> A), B) -> ((fop A:i, B:0), undef)
+ // ...or...
+ // (fop.x2 (vector_shuffle<u,i> A), B) -> (undef, (fop A:i, B:1))
+ // ...where i is a valid index and u is poison.
+ const EVT VectorVT = N->getValueType(0);
+ if (!Isv2x16VT(VectorVT))
+ return SDValue();
+
+ SDLoc DL(N);
+
+ SDValue ShufOp = N->getOperand(0);
+ SDValue VectOp = N->getOperand(1);
+ bool Swapped = false;
+
+ // canonicalize shuffle to op0
+ if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
+ std::swap(ShufOp, VectOp);
+ Swapped = true;
+ }
+
+ if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
+ return SDValue();
+
+ auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
+ int LiveLane; // exclusively live lane
+ for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
+ // check if the current lane is live and the other lane is dead
+ if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
+ ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
+ break;
+ }
+ if (LiveLane == 2)
+ return SDValue();
+
+ int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
+ const EVT ScalarVT = VectorVT.getScalarType();
+ SDValue Lanes[2] = {};
+ for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
+ if (LaneID == (unsigned)LiveLane) {
+ SDValue Operands[2] = {
+ DCI.DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
+ ElementIdx),
+ DCI.DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
+ // preserve the order of operands
+ if (Swapped)
+ std::swap(Operands[0], Operands[1]);
+ LaneVal = DCI.DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
+ } else {
+ LaneVal = DCI.DAG.getUNDEF(ScalarVT);
+ }
+ }
+ return DCI.DAG.getBuildVector(VectorVT, DL, Lanes);
+}
+
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
///
static SDValue PerformADDCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
- if (OptLevel == CodeGenOptLevel::None)
- return SDValue();
-
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
// Skip non-integer, non-scalar case
EVT VT = N0.getValueType();
- if (VT.isVector() || VT != MVT::i32)
+ if (VT.isVector())
+ return PerformPackedOpCombine(N, DCI);
+ if (VT != MVT::i32)
+ return SDValue();
+
+ if (OptLevel == CodeGenOptLevel::None)
return SDValue();
// First try with the default operand order.
@@ -5092,7 +5181,10 @@ static SDValue PerformFADDCombine(SDNode *N,
SDValue N1 = N->getOperand(1);
EVT VT = N0.getValueType();
- if (VT.isVector() || !(VT == MVT::f32 || VT == MVT::f64))
+ if (VT.isVector())
+ return PerformPackedOpCombine(N, DCI);
+
+ if (!(VT == MVT::f32 || VT == MVT::f64))
return SDValue();
// First try with the default operand order.
@@ -5195,7 +5287,7 @@ static SDValue PerformANDCombine(SDNode *N,
DCI.CombineTo(N, Val, AddTo);
}
- return SDValue();
+ return PerformPackedOpCombine(N, DCI);
}
static SDValue PerformREMCombine(SDNode *N,
@@ -5676,6 +5768,16 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformADDCombine(N, DCI, OptLevel);
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
+ case ISD::FMUL:
+ case ISD::FMINNUM:
+ case ISD::FMAXIMUM:
+ case ISD::UMIN:
+ case ISD::UMAX:
+ case ISD::SMIN:
+ case ISD::SMAX:
+ case ISD::OR:
+ case ISD::XOR:
+ return PerformPackedOpCombine(N, DCI);
case ISD::MUL:
return PerformMULCombine(N, DCI, OptLevel);
case ISD::SHL:
diff --git a/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
index d5b451dad7bc3..ca03550bdefcd 100644
--- a/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
+++ b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
@@ -5,10 +5,10 @@
; RUN: %if ptxas-12.8 %{ llc < %s -mcpu=sm_80 -mattr=+ptx70 -O0 \
; RUN: -disable-post-ra -verify-machineinstrs \
; RUN: | %ptxas-verify -arch=sm_80 %}
-; RUN: llc < %s -mcpu=sm_100 -mattr=+ptx87 -O0 \
+; RUN: llc < %s -mcpu=sm_100 -mattr=+ptx86 -O0 \
; RUN: -disable-post-ra -verify-machineinstrs \
; RUN: | FileCheck -check-prefixes CHECK,CHECK-SM100 %s
-; RUN: %if ptxas-12.8 %{ llc < %s -mcpu=sm_100 -mattr=+ptx87 -O0 \
+; RUN: %if ptxas-12.8 %{ llc < %s -mcpu=sm_100 -mattr=+ptx86 -O0 \
; RUN: -disable-post-ra -verify-machineinstrs \
; RUN: | %ptxas-verify -arch=sm_100 %}
target triple = "nvptx64-nvidia-cuda"
@@ -43,45 +43,22 @@ define half @reduce_fadd_half(<8 x half> %in) {
}
define half @reduce_fadd_half_reassoc(<8 x half> %in) {
-; CHECK-SM80-LABEL: reduce_fadd_half_reassoc(
-; CHECK-SM80: {
-; CHECK-SM80-NEXT: .reg .b16 %rs<6>;
-; CHECK-SM80-NEXT: .reg .b32 %r<10>;
-; CHECK-SM80-EMPTY:
-; CHECK-SM80-NEXT: // %bb.0:
-; CHECK-SM80-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_reassoc_param_0];
-; CHECK-SM80-NEXT: add.rn.f16x2 %r5, %r2, %r4;
-; CHECK-SM80-NEXT: add.rn.f16x2 %r6, %r1, %r3;
-; CHECK-SM80-NEXT: add.rn.f16x2 %r7, %r6, %r5;
-; CHECK-SM80-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs1}, %r7; }
-; CHECK-SM80-NEXT: // implicit-def: %rs2
-; CHECK-SM80-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM80-NEXT: add.rn.f16x2 %r9, %r7, %r8;
-; CHECK-SM80-NEXT: { .reg .b16 tmp; mov.b32 {%rs3, tmp}, %r9; }
-; CHECK-SM80-NEXT: mov.b16 %rs4, 0x0000;
-; CHECK-SM80-NEXT: add.rn.f16 %rs5, %rs3, %rs4;
-; CHECK-SM80-NEXT: st.param.b16 [func_retval0], %rs5;
-; CHECK-SM80-NEXT: ret;
-;
-; CHECK-SM100-LABEL: reduce_fadd_half_reassoc(
-; CHECK-SM100: {
-; CHECK-SM100-NEXT: .reg .b16 %rs<6>;
-; CHECK-SM100-NEXT: .reg .b32 %r<10>;
-; CHECK-SM100-EMPTY:
-; CHECK-SM100-NEXT: // %bb.0:
-; CHECK-SM100-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_reassoc_param_0];
-; CHECK-SM100-NEXT: add.rn.f16x2 %r5, %r2, %r4;
-; CHECK-SM100-NEXT: add.rn.f16x2 %r6, %r1, %r3;
-; CHECK-SM100-NEXT: add.rn.f16x2 %r7, %r6, %r5;
-; CHECK-SM100-NEXT: mov.b32 {_, %rs1}, %r7;
-; CHECK-SM100-NEXT: // implicit-def: %rs2
-; CHECK-SM100-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM100-NEXT: add.rn.f16x2 %r9, %r7, %r8;
-; CHECK-SM100-NEXT: mov.b32 {%rs3, _}, %r9;
-; CHECK-SM100-NEXT: mov.b16 %rs4, 0x0000;
-; CHECK-SM100-NEXT: add.rn.f16 %rs5, %rs3, %rs4;
-; CHECK-SM100-NEXT: st.param.b16 [func_retval0], %rs5;
-; CHECK-SM100-NEXT: ret;
+; CHECK-LABEL: reduce_fadd_half_reassoc(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<6>;
+; CHECK-NEXT: .reg .b32 %r<8>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_reassoc_param_0];
+; CHECK-NEXT: add.rn.f16x2 %r5, %r2, %r4;
+; CHECK-NEXT: add.rn.f16x2 %r6, %r1, %r3;
+; CHECK-NEXT: add.rn.f16x2 %r7, %r6, %r5;
+; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r7;
+; CHECK-NEXT: add.rn.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: mov.b16 %rs4, 0x0000;
+; CHECK-NEXT: add.rn.f16 %rs5, %rs3, %rs4;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs5;
+; CHECK-NEXT: ret;
%res = call reassoc half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
ret half %res
}
@@ -205,41 +182,20 @@ define half @reduce_fmul_half(<8 x half> %in) {
}
define half @reduce_fmul_half_reassoc(<8 x half> %in) {
-; CHECK-SM80-LABEL: reduce_fmul_half_reassoc(
-; CHECK-SM80: {
-; CHECK-SM80-NEXT: .reg .b16 %rs<4>;
-; CHECK-SM80-NEXT: .reg .b32 %r<10>;
-; CHECK-SM80-EMPTY:
-; CHECK-SM80-NEXT: // %bb.0:
-; CHECK-SM80-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_reassoc_param_0];
-; CHECK-SM80-NEXT: mul.rn.f16x2 %r5, %r2, %r4;
-; CHECK-SM80-NEXT: mul.rn.f16x2 %r6, %r1, %r3;
-; CHECK-SM80-NEXT: mul.rn.f16x2 %r7, %r6, %r5;
-; CHECK-SM80-NEXT: { .reg .b16 tmp; mov.b32 {tmp, %rs1}, %r7; }
-; CHECK-SM80-NEXT: // implicit-def: %rs2
-; CHECK-SM80-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM80-NEXT: mul.rn.f16x2 %r9, %r7, %r8;
-; CHECK-SM80-NEXT: { .reg .b16 tmp; mov.b32 {%rs3, tmp}, %r9; }
-; CHECK-SM80-NEXT: st.param.b16 [func_retval0], %rs3;
-; CHECK-SM80-NEXT: ret;
-;
-; CHECK-SM100-LABEL: reduce_fmul_half_reassoc(
-; CHECK-SM100: {
-; CHECK-SM100-NEXT: .reg .b16 %rs<4>;
-; CHECK-SM100-NEXT: .reg .b32 %r<10>;
-; CHECK-SM100-EMPTY:
-; CHECK-SM100-NEXT: // %bb.0:
-; CHECK-SM100-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_reassoc_param_0];
-; CHECK-SM100-NEXT: mul.rn.f16x2 %r5, %r2, %r4;
-; CHECK-SM100-NEXT: mul.rn.f16x2 %r6, %r1, %r3;
-; CHECK-SM100-NEXT: mul.rn.f16x2 %r7, %r6, %r5;
-; CHECK-SM100-NEXT: mov.b32 {_, %rs1}, %r7;
-; CHECK-SM100-NEXT: // implicit-def: %rs2
-; CHECK-SM100-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM100-NEXT: mul.rn.f16x2 %r9, %r7, %r8;
-; CHECK-SM100-NEXT: mov.b32 {%rs3, _}, %r9;
-; CHECK-SM100-NEXT: st.param.b16 [func_retval0], %rs3;
-; CHECK-SM100-NEXT: ret;
+; CHECK-LABEL: reduce_fmul_half_reassoc(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<4>;
+; CHECK-NEXT: .reg .b32 %r<8>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_reassoc_param_0];
+; CHECK-NEXT: mul.rn.f16x2 %r5, %r2, %r4;
+; CHECK-NEXT: mul.rn.f16x2 %r6, %r1, %r3;
+; CHECK-NEXT: mul.rn.f16x2 %r7, %r6, %r5;
+; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r7;
+; CHECK-NEXT: mul.rn.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: st.param.b16 [func_retval0], %rs3;
+; CHECK-NEXT: ret;
%res = call reassoc half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %in)
ret half %res
}
@@ -401,7 +357,6 @@ define half @reduce_fmax_half_reassoc_nonpow2(<7 x half> %in) {
; Check straight-line reduction.
define float @reduce_fmax_float(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fmax_float(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -423,7 +378,6 @@ define float @reduce_fmax_float(<8 x float> %in) {
}
define float @reduce_fmax_float_reassoc(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fmax_float_reassoc(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -445,7 +399,6 @@ define float @reduce_fmax_float_reassoc(<8 x float> %in) {
}
define float @reduce_fmax_float_reassoc_nonpow2(<7 x float> %in) {
-;
; CHECK-LABEL: reduce_fmax_float_reassoc_nonpow2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<14>;
@@ -533,7 +486,6 @@ define half @reduce_fmin_half_reassoc_nonpow2(<7 x half> %in) {
; Check straight-line reduction.
define float @reduce_fmin_float(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fmin_float(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -555,7 +507,6 @@ define float @reduce_fmin_float(<8 x float> %in) {
}
define float @reduce_fmin_float_reassoc(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fmin_float_reassoc(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -665,7 +616,6 @@ define half @reduce_fmaximum_half_reassoc_nonpow2(<7 x half> %in) {
; Check straight-line reduction.
define float @reduce_fmaximum_float(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fmaximum_float(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -687,7 +637,6 @@ define float @reduce_fmaximum_float(<8 x float> %in) {
}
define float @reduce_fmaximum_float_reassoc(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fmaximum_float_reassoc(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -709,7 +658,6 @@ define float @reduce_fmaximum_float_reassoc(<8 x float> %in) {
}
define float @reduce_fmaximum_float_reassoc_nonpow2(<7 x float> %in) {
-;
; CHECK-LABEL: reduce_fmaximum_float_reassoc_nonpow2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<14>;
@@ -797,7 +745,6 @@ define half @reduce_fminimum_half_reassoc_nonpow2(<7 x half> %in) {
; Check straight-line reduction.
define float @reduce_fminimum_float(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fminimum_float(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -819,7 +766,6 @@ define float @reduce_fminimum_float(<8 x float> %in) {
}
define float @reduce_fminimum_float_reassoc(<8 x float> %in) {
-;
; CHECK-LABEL: reduce_fminimum_float_reassoc(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<16>;
@@ -841,7 +787,6 @@ define float @reduce_fminimum_float_reassoc(<8 x float> %in) {
}
define float @reduce_fminimum_float_reassoc_nonpow2(<7 x float> %in) {
-;
; CHECK-LABEL: reduce_fminimum_float_reassoc_nonpow2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<14>;
@@ -888,20 +833,17 @@ define i16 @reduce_add_i16(<8 x i16> %in) {
; CHECK-SM100-LABEL: reduce_add_i16(
; CHECK-SM100: {
; CHECK-SM100-NEXT: .reg .b16 %rs<4>;
-; CHECK-SM100-NEXT: .reg .b32 %r<11>;
+; CHECK-SM100-NEXT: .reg .b32 %r<9>;
; CHECK-SM100-EMPTY:
; CHECK-SM100-NEXT: // %bb.0:
; CHECK-SM100-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_add_i16_param_0];
; CHECK-SM100-NEXT: add.s16x2 %r5, %r2, %r4;
; CHECK-SM100-NEXT: add.s16x2 %r6, %r1, %r3;
; CHECK-SM100-NEXT: add.s16x2 %r7, %r6, %r5;
-; CHECK-SM100-NEXT: mov.b32 {_, %rs1}, %r7;
-; CHECK-SM100-NEXT: // implicit-def: %rs2
-; CHECK-SM100-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM100-NEXT: add.s16x2 %r9, %r7, %r8;
-; CHECK-SM100-NEXT: mov.b32 {%rs3, _}, %r9;
-; CHECK-SM100-NEXT: cvt.u32.u16 %r10, %rs3;
-; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-SM100-NEXT: mov.b32 {%rs1, %rs2}, %r7;
+; CHECK-SM100-NEXT: add.s16 %rs3, %rs1, %rs2;
+; CHECK-SM100-NEXT: cvt.u32.u16 %r8, %rs3;
+; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-SM100-NEXT: ret;
%res = call i16 @llvm.vector.reduce.add(<8 x i16> %in)
ret i16 %res
@@ -1114,20 +1056,17 @@ define i16 @reduce_umax_i16(<8 x i16> %in) {
; CHECK-SM100-LABEL: reduce_umax_i16(
; CHECK-SM100: {
; CHECK-SM100-NEXT: .reg .b16 %rs<4>;
-; CHECK-SM100-NEXT: .reg .b32 %r<11>;
+; CHECK-SM100-NEXT: .reg .b32 %r<9>;
; CHECK-SM100-EMPTY:
; CHECK-SM100-NEXT: // %bb.0:
; CHECK-SM100-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_umax_i16_param_0];
; CHECK-SM100-NEXT: max.u16x2 %r5, %r2, %r4;
; CHECK-SM100-NEXT: max.u16x2 %r6, %r1, %r3;
; CHECK-SM100-NEXT: max.u16x2 %r7, %r6, %r5;
-; CHECK-SM100-NEXT: mov.b32 {_, %rs1}, %r7;
-; CHECK-SM100-NEXT: // implicit-def: %rs2
-; CHECK-SM100-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM100-NEXT: max.u16x2 %r9, %r7, %r8;
-; CHECK-SM100-NEXT: mov.b32 {%rs3, _}, %r9;
-; CHECK-SM100-NEXT: cvt.u32.u16 %r10, %rs3;
-; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-SM100-NEXT: mov.b32 {%rs1, %rs2}, %r7;
+; CHECK-SM100-NEXT: max.u16 %rs3, %rs1, %rs2;
+; CHECK-SM100-NEXT: cvt.u32.u16 %r8, %rs3;
+; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-SM100-NEXT: ret;
%res = call i16 @llvm.vector.reduce.umax(<8 x i16> %in)
ret i16 %res
@@ -1248,20 +1187,17 @@ define i16 @reduce_umin_i16(<8 x i16> %in) {
; CHECK-SM100-LABEL: reduce_umin_i16(
; CHECK-SM100: {
; CHECK-SM100-NEXT: .reg .b16 %rs<4>;
-; CHECK-SM100-NEXT: .reg .b32 %r<11>;
+; CHECK-SM100-NEXT: .reg .b32 %r<9>;
; CHECK-SM100-EMPTY:
; CHECK-SM100-NEXT: // %bb.0:
; CHECK-SM100-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_umin_i16_param_0];
; CHECK-SM100-NEXT: min.u16x2 %r5, %r2, %r4;
; CHECK-SM100-NEXT: min.u16x2 %r6, %r1, %r3;
; CHECK-SM100-NEXT: min.u16x2 %r7, %r6, %r5;
-; CHECK-SM100-NEXT: mov.b32 {_, %rs1}, %r7;
-; CHECK-SM100-NEXT: // implicit-def: %rs2
-; CHECK-SM100-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM100-NEXT: min.u16x2 %r9, %r7, %r8;
-; CHECK-SM100-NEXT: mov.b32 {%rs3, _}, %r9;
-; CHECK-SM100-NEXT: cvt.u32.u16 %r10, %rs3;
-; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-SM100-NEXT: mov.b32 {%rs1, %rs2}, %r7;
+; CHECK-SM100-NEXT: min.u16 %rs3, %rs1, %rs2;
+; CHECK-SM100-NEXT: cvt.u32.u16 %r8, %rs3;
+; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-SM100-NEXT: ret;
%res = call i16 @llvm.vector.reduce.umin(<8 x i16> %in)
ret i16 %res
@@ -1382,20 +1318,17 @@ define i16 @reduce_smax_i16(<8 x i16> %in) {
; CHECK-SM100-LABEL: reduce_smax_i16(
; CHECK-SM100: {
; CHECK-SM100-NEXT: .reg .b16 %rs<4>;
-; CHECK-SM100-NEXT: .reg .b32 %r<11>;
+; CHECK-SM100-NEXT: .reg .b32 %r<9>;
; CHECK-SM100-EMPTY:
; CHECK-SM100-NEXT: // %bb.0:
; CHECK-SM100-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_smax_i16_param_0];
; CHECK-SM100-NEXT: max.s16x2 %r5, %r2, %r4;
; CHECK-SM100-NEXT: max.s16x2 %r6, %r1, %r3;
; CHECK-SM100-NEXT: max.s16x2 %r7, %r6, %r5;
-; CHECK-SM100-NEXT: mov.b32 {_, %rs1}, %r7;
-; CHECK-SM100-NEXT: // implicit-def: %rs2
-; CHECK-SM100-NEXT: mov.b32 %r8, {%rs1, %rs2};
-; CHECK-SM100-NEXT: max.s16x2 %r9, %r7, %r8;
-; CHECK-SM100-NEXT: mov.b32 {%rs3, _}, %r9;
-; CHECK-SM100-NEXT: cvt.u32.u16 %r10, %rs3;
-; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-SM100-NEXT: mov.b32 {%rs1, %rs2}, %r7;
+; CHECK-SM100-NEXT: max.s16 %rs3, %rs1, %rs2;
+; CHECK-SM100-NEXT: cvt.u32.u16 %r8, %rs3;
+; CHECK-SM100-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-SM100-NEXT: ret;
%res = call i16 @llvm.vector.reduce.smax(<8 x i16> %in)
ret i16 %res
@@ -1516,20 +1449,17 @@ define i16 @reduce_smin_i16(<8 x i16> %in) {
; CHECK-SM100-LABEL: reduce_smin_i16(
; CHECK-SM100: {
; CHECK-SM100-NEXT: .reg .b16 %rs<4>;
-; CHECK-SM100-NEXT: .reg .b32 %r<11>;
+; CHECK-SM100-NEXT: .reg .b32 %r<9>;
; CHECK-SM100-EMPTY:
; CHECK-SM100-NEXT: // %bb.0:
; CHECK-SM100-NEXT: ld.param.v4.b32 {%r1, %r2, %r3, %r4}, [reduce_smin_i16_param_0];
; CHECK-SM100-NEXT: min.s16x2 %r5, %r2, %r4;
; CHECK-SM100-NEXT: min.s16x2 %r6, %r1, %r3;
; CHECK-SM100-NEXT: min.s16x2 %r7, %r6, %r5;
-; CHECK-SM100-NEXT: mov.b32 {_, %rs1}, %r7;
-; CHECK-SM100-NEXT: // implicit-def: %rs2
-; CHECK-SM100-NEXT: mov.b32...
[truncated]
This seems to be a fix for an inefficient lowering of vecreduce_fadd. NVPTX isn't custom lowering vecreduce_fadd, correct? If that is the case, then this is an issue other targets could/do face. Given that, why not improve the vecreduce_fadd lowering or implement this in DAGCombiner?

I too wonder if this can be made more general. Is there a reason this optimization needs to be NVPTX-specific? Maybe we could update the lowering of these operations based on whether vector shuffle is legal. Or maybe we could add a DAG combiner rule that looks for BUILD_VECTOR <- [BIN_OP] <- extract_vector_elt and converts it to extract_vector_elt <- [BIN_OP].
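For illustration, here is a rough sketch of the kind of target-independent combine being suggested: scalarize an extract of a one-use vector binop so the last reduction step is done on scalars. This is not what the patch implements and not existing DAGCombiner code; the function name, the opcode whitelist, and the one-use restriction are assumptions made purely for the sketch.

#include "llvm/CodeGen/SelectionDAG.h"

using namespace llvm;

// Hypothetical fold (sketch only, not part of this patch):
//   (extract_vector_elt (binop X, Y), i)
//     --> (binop (extract_vector_elt X, i), (extract_vector_elt Y, i))
// so the final step of a shuffle-based reduction becomes a scalar op.
static SDValue scalarizeExtractOfBinOp(SDNode *N, SelectionDAG &DAG) {
  if (N->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
    return SDValue();

  SDValue Vec = N->getOperand(0);
  auto *IdxC = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!IdxC || !Vec.hasOneUse())
    return SDValue();

  // Restrict the sketch to a few simple binary opcodes.
  switch (Vec.getOpcode()) {
  case ISD::ADD:
  case ISD::FADD:
  case ISD::MUL:
  case ISD::FMUL:
  case ISD::AND:
  case ISD::OR:
  case ISD::XOR:
  case ISD::SMIN:
  case ISD::SMAX:
  case ISD::UMIN:
  case ISD::UMAX:
    break;
  default:
    return SDValue();
  }

  SDLoc DL(N);
  const EVT EltVT = N->getValueType(0);
  const uint64_t Idx = IdxC->getZExtValue();
  SDValue L = DAG.getExtractVectorElt(DL, EltVT, Vec.getOperand(0), Idx);
  SDValue R = DAG.getExtractVectorElt(DL, EltVT, Vec.getOperand(1), Idx);
  // A real implementation would also preserve fast-math flags and consult
  // legality/profitability hooks before rewriting.
  return DAG.getNode(Vec.getOpcode(), DL, EltVT, L, R);
}

As the thread notes below, a generic rule along these lines may not pay off everywhere: some targets (e.g. x86) already handle the shuffle-based form well.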
The default expansion of reduction intrinsics happens in [...]. Here's a trace of what the old behavior is/was on this IR:
target triple = "nvptx64-nvidia-cuda"
target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
define i16 @reduce_smin_i16(<8 x i16> %in) {
%res = call i16 @llvm.vector.reduce.smin(<8 x i16> %in)
ret i16 %res
}
I ported this to DAGCombiner, but the optimization only winds up benefiting NVPTX. What do you guys think? Should we keep it in NVPTXISelLowering?
Force-pushed from f879526 to 371a3ad.
Some targets like x86 seem to handle this sequence just fine.
✅ With the latest revision this PR passed the C/C++ code formatter. |
This reverts commit 8cbda00.
The result of a reduction needs to be a scalar. When expressed as a sequence of vector ops, the final op needs to reduce two (or more) lanes belonging to the same register. On some targets a shuffle is not available, and this results in two extra movs to set up another vector register with the lanes swapped. This pattern is now handled better by turning it into a scalar op.
This reverts commit 371a3ad.
This reverts commit 9d84bd3.
Force-pushed from 371a3ad to dfa1a0f.
So I've decided to merge this back into NVPTXISelLowering.