[NVPTX] fold movs into loads and stores #144581
Conversation
@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

Changes: Fold movs into loads and stores by increasing the number of return values or operands. For example:
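A minimal sketch of the store case, taken from the doc comment in the patch below:

v: v2f16 = build_vector a:f16, b:f16
StoreRetval v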
...becomes...
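StoreRetvalV2 a:f16, b:f16

The load direction is analogous: when a load's packed result is consumed only by extract_vector_elt nodes, the combine widens the load so each scalar becomes its own return value (for instance, NVPTXISD::LoadV2 yielding v2f16 values becomes NVPTXISD::LoadV4 yielding f16 values).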
Patch is 327.00 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/144581.diff

23 files affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 492f4ab76fdbb..e736b2ca6a151 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -238,18 +238,11 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
return std::nullopt;
LLVM_FALLTHROUGH;
case MVT::v2i8:
- case MVT::v2i16:
case MVT::v2i32:
case MVT::v2i64:
- case MVT::v2f16:
- case MVT::v2bf16:
case MVT::v2f32:
case MVT::v2f64:
- case MVT::v4i8:
- case MVT::v4i16:
case MVT::v4i32:
- case MVT::v4f16:
- case MVT::v4bf16:
case MVT::v4f32:
// This is a "native" vector type
return std::pair(NumElts, EltVT);
@@ -262,6 +255,13 @@ getVectorLoweringShape(EVT VectorEVT, bool CanLowerTo256Bit) {
if (!CanLowerTo256Bit)
return std::nullopt;
LLVM_FALLTHROUGH;
+ case MVT::v2i16: // <1 x i16x2>
+ case MVT::v2f16: // <1 x f16x2>
+ case MVT::v2bf16: // <1 x bf16x2>
+ case MVT::v4i8: // <1 x i8x4>
+ case MVT::v4i16: // <2 x i16x2>
+ case MVT::v4f16: // <2 x f16x2>
+ case MVT::v4bf16: // <2 x bf16x2>
case MVT::v8i8: // <2 x i8x4>
case MVT::v8f16: // <4 x f16x2>
case MVT::v8bf16: // <4 x bf16x2>
@@ -845,7 +845,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3464,19 +3464,16 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
unsigned I = 0;
for (const unsigned NumElts : VectorInfo) {
const EVT EltVT = VTs[I];
- const EVT LoadVT = [&]() -> EVT {
- // i1 is loaded/stored as i8.
- if (EltVT == MVT::i1)
- return MVT::i8;
- // getLoad needs a vector type, but it can't handle
- // vectors which contain v2f16 or v2bf16 elements. So we must load
- // using i32 here and then bitcast back.
- if (EltVT.isVector())
- return MVT::getIntegerVT(EltVT.getFixedSizeInBits());
- return EltVT;
- }();
+ // i1 is loaded/stored as i8
+ const EVT LoadVT = EltVT == MVT::i1 ? MVT::i8 : EltVT;
+ // The element may be a packed type (e.g. v2f16, v4i8) holding multiple
+ // elements; account for its packing factor.
+ const unsigned PackingAmt =
+ LoadVT.isVector() ? LoadVT.getVectorNumElements() : 1;
+
+ const EVT VecVT = EVT::getVectorVT(
+ F->getContext(), LoadVT.getScalarType(), NumElts * PackingAmt);
- const EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
SDValue VecAddr = DAG.getObjectPtrOffset(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
@@ -3496,8 +3493,10 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (P.getNode())
P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
- SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, LoadVT, P,
- DAG.getIntPtrConstant(J, dl));
+ SDValue Elt =
+ DAG.getNode(LoadVT.isVector() ? ISD::EXTRACT_SUBVECTOR
+ : ISD::EXTRACT_VECTOR_ELT,
+ dl, LoadVT, P, DAG.getIntPtrConstant(J * PackingAmt, dl));
// Extend or truncate the element if necessary (e.g. an i8 is loaded
// into an i16 register)
@@ -3511,9 +3510,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
Elt);
} else if (ExpactedVT.bitsLT(Elt.getValueType())) {
Elt = DAG.getNode(ISD::TRUNCATE, dl, ExpactedVT, Elt);
- } else {
- // v2f16 was loaded as an i32. Now we must bitcast it back.
- Elt = DAG.getBitcast(EltVT, Elt);
}
InVals.push_back(Elt);
}
@@ -5047,26 +5043,229 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
return SDValue();
}
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
- std::size_t Back) {
+/// Combine extractelts into a load by increasing the number of return values.
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ // Don't run this optimization before the legalizer
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ EVT ElemVT = N->getValueType(0);
+ if (!Isv2x16VT(ElemVT))
+ return SDValue();
+
+ // Check whether all outputs are either used by an extractelt or are
+ // glue/chain nodes
+ if (!all_of(N->uses(), [&](SDUse &U) {
+ return U.getValueType() != ElemVT ||
+ (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ // also check that the extractelt is used if this is an
+ // ISD::LOAD, otherwise it may be optimized by something else
+ (N->getOpcode() != ISD::LOAD || !U.getUser()->use_empty()));
+ }))
+ return SDValue();
+
+ auto *LD = cast<MemSDNode>(N);
+ EVT MemVT = LD->getMemoryVT();
+ SDLoc DL(LD);
+
+ // The new opcode after we double the number of return values.
+ NVPTXISD::NodeType Opcode;
+ SmallVector<SDValue> Operands(LD->ops());
+ switch (LD->getOpcode()) {
+ // Any packed type is legal, so the legalizer will not have lowered ISD::LOAD
+ // -> NVPTXISD::Load. We have to do it here.
+ case ISD::LOAD:
+ Opcode = NVPTXISD::LoadV2;
+ {
+ Operands.push_back(DCI.DAG.getIntPtrConstant(
+ cast<LoadSDNode>(LD)->getExtensionType(), DL));
+ Align Alignment = LD->getAlign();
+ const auto &TD = DCI.DAG.getDataLayout();
+ Align PrefAlign =
+ TD.getPrefTypeAlign(MemVT.getTypeForEVT(*DCI.DAG.getContext()));
+ if (Alignment < PrefAlign) {
+ // This load is not sufficiently aligned, so bail out and let this
+ // vector load be scalarized. Note that we may still be able to emit
+ // smaller vector loads. For example, if we are loading a <4 x float>
+ // with an alignment of 8, this check will fail but the legalizer will
+ // try again with 2 x <2 x float>, which will succeed with an alignment
+ // of 8.
+ return SDValue();
+ }
+ }
+ break;
+ case NVPTXISD::LoadParamV2:
+ Opcode = NVPTXISD::LoadParamV4;
+ break;
+ case NVPTXISD::LoadV2:
+ Opcode = NVPTXISD::LoadV4;
+ break;
+ case NVPTXISD::LoadV4:
+ // PTX doesn't support v8 for 16-bit values
+ case NVPTXISD::LoadV8:
+ // PTX doesn't support the next doubling of outputs
+ return SDValue();
+ }
+
+ SmallVector<EVT> NewVTs;
+ for (EVT VT : LD->values()) {
+ if (VT == ElemVT) {
+ const EVT ScalarVT = ElemVT.getVectorElementType();
+ NewVTs.insert(NewVTs.end(), {ScalarVT, ScalarVT});
+ } else
+ NewVTs.push_back(VT);
+ }
+
+ // Create the new load
+ SDValue NewLoad =
+ DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+ Operands, MemVT, LD->getMemOperand());
+
+ // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+ // the outputs the same. These nodes will be optimized away in later
+ // DAGCombiner iterations.
+ SmallVector<SDValue> Results;
+ for (unsigned I = 0; I < NewLoad->getNumValues();) {
+ if (NewLoad->getValueType(I) == ElemVT.getVectorElementType()) {
+ Results.push_back(DCI.DAG.getBuildVector(
+ ElemVT, DL, {NewLoad.getValue(I), NewLoad.getValue(I + 1)}));
+ I += 2;
+ } else {
+ Results.push_back(NewLoad.getValue(I));
+ I += 1;
+ }
+ }
+
+ return DCI.DAG.getMergeValues(Results, DL);
+}
+
+/// Fold a packing mov into a store. This may help lower register pressure.
+///
+/// ex:
+/// v: v2f16 = build_vector a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
+static SDValue combinePackingMovIntoStore(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned Front, unsigned Back) {
+ // Don't run this optimization before the legalizer
+ if (DCI.isBeforeLegalize())
+ return SDValue();
+
+ // Get the type of the operands being stored.
+ EVT ElementVT = N->getOperand(Front).getValueType();
+
+ if (!Isv2x16VT(ElementVT))
+ return SDValue();
+
+ // We want to run this as late as possible since other optimizations may
+ // eliminate the BUILD_VECTORs.
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
+ auto *ST = cast<MemSDNode>(N);
+ EVT MemVT = ElementVT.getVectorElementType();
+
+ // The new opcode after we double the number of operands.
+ NVPTXISD::NodeType Opcode;
+ switch (N->getOpcode()) {
+ case NVPTXISD::StoreParam:
+ Opcode = NVPTXISD::StoreParamV2;
+ break;
+ case NVPTXISD::StoreParamV2:
+ Opcode = NVPTXISD::StoreParamV4;
+ break;
+ case NVPTXISD::StoreRetval:
+ Opcode = NVPTXISD::StoreRetvalV2;
+ break;
+ case NVPTXISD::StoreRetvalV2:
+ Opcode = NVPTXISD::StoreRetvalV4;
+ break;
+ case NVPTXISD::StoreV2:
+ MemVT = ST->getMemoryVT();
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ case NVPTXISD::StoreV4:
+ // PTX doesn't support v8 for 16-bit values
+ case NVPTXISD::StoreParamV4:
+ case NVPTXISD::StoreRetvalV4:
+ case NVPTXISD::StoreV8:
+ // PTX doesn't support the next doubling of operands for these opcodes.
+ return SDValue();
+ default:
+ llvm_unreachable("Unhandled store opcode");
+ }
+
+ // Scan the operands and if they're all BUILD_VECTORs, we'll have gathered
+ // their elements.
+ SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+ for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+ if (BV.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ // If the operand has multiple uses, this optimization can increase register
+ // pressure.
+ if (!BV.hasOneUse())
+ return SDValue();
+
+ // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+ // any signs they may be folded by some other pattern or rule.
+ for (SDValue Op : BV->ops()) {
+ // Peek through bitcasts
+ if (Op.getOpcode() == ISD::BITCAST)
+ Op = Op.getOperand(0);
+
+ // This may be folded into a PRMT.
+ if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+ Op->getOperand(0).getValueType() == MVT::i32)
+ return SDValue();
+
+ // This may be folded into cvt.bf16x2
+ if (Op.getOpcode() == ISD::FP_ROUND)
+ return SDValue();
+ }
+ Operands.insert(Operands.end(), {BV.getOperand(0), BV.getOperand(1)});
+ }
+ for (SDValue Op : N->ops().take_back(Back))
+ Operands.push_back(Op);
+
+ // Now we replace the store
+ return DCI.DAG.getMemIntrinsicNode(Opcode, SDLoc(N), N->getVTList(),
+ Operands, MemVT, ST->getMemOperand());
+}
+
+static SDValue PerformStoreCombineHelper(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned Front, unsigned Back) {
if (all_of(N->ops().drop_front(Front).drop_back(Back),
[](const SDUse &U) { return U.get()->isUndef(); }))
// Operand 0 is the previous value in the chain. Cannot return EntryToken
// as the previous value will become unused and eliminated later.
return N->getOperand(0);
- return SDValue();
+ return combinePackingMovIntoStore(N, DCI, Front, Back);
+}
+
+static SDValue PerformStoreCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ return combinePackingMovIntoStore(N, DCI, 1, 2);
}
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
// Operands from the 3rd to the 2nd last one are the values to be stored.
// {Chain, ArgID, Offset, Val, Glue}
- return PerformStoreCombineHelper(N, 3, 1);
+ return PerformStoreCombineHelper(N, DCI, 3, 1);
}
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
// Operands from the 2nd to the last one are the values to be stored
- return PerformStoreCombineHelper(N, 2, 0);
+ return PerformStoreCombineHelper(N, DCI, 2, 0);
}
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5697,14 +5896,22 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformREMCombine(N, DCI, OptLevel);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI, STI.getSmVersion());
+ case ISD::LOAD:
+ case NVPTXISD::LoadParamV2:
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4:
+ return combineUnpackingMovIntoLoad(N, DCI);
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
- return PerformStoreRetvalCombine(N);
+ return PerformStoreRetvalCombine(N, DCI);
case NVPTXISD::StoreParam:
case NVPTXISD::StoreParamV2:
case NVPTXISD::StoreParamV4:
- return PerformStoreParamCombine(N);
+ return PerformStoreParamCombine(N, DCI);
+ case NVPTXISD::StoreV2:
+ case NVPTXISD::StoreV4:
+ return PerformStoreCombine(N, DCI);
case ISD::EXTRACT_VECTOR_ELT:
return PerformEXTRACTCombine(N, DCI);
case ISD::VSELECT:
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 32225ed04e2d9..95af9c64a73ac 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -146,37 +146,35 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
; SM70-NEXT: .reg .b16 %rs<5>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b32 %r<22>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
-; SM70-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
-; SM70-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
-; SM70-NEXT: mov.b32 {%rs1, %rs2}, %r2;
+; SM70-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_faddx2_param_0];
+; SM70-NEXT: ld.param.v2.b16 {%rs3, %rs4}, [test_faddx2_param_1];
+; SM70-NEXT: cvt.u32.u16 %r1, %rs4;
+; SM70-NEXT: shl.b32 %r2, %r1, 16;
; SM70-NEXT: cvt.u32.u16 %r3, %rs2;
; SM70-NEXT: shl.b32 %r4, %r3, 16;
-; SM70-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM70-NEXT: cvt.u32.u16 %r5, %rs4;
-; SM70-NEXT: shl.b32 %r6, %r5, 16;
-; SM70-NEXT: add.rn.f32 %r7, %r6, %r4;
-; SM70-NEXT: bfe.u32 %r8, %r7, 16, 1;
-; SM70-NEXT: add.s32 %r9, %r8, %r7;
-; SM70-NEXT: add.s32 %r10, %r9, 32767;
-; SM70-NEXT: setp.nan.f32 %p1, %r7, %r7;
-; SM70-NEXT: or.b32 %r11, %r7, 4194304;
-; SM70-NEXT: selp.b32 %r12, %r11, %r10, %p1;
+; SM70-NEXT: add.rn.f32 %r5, %r4, %r2;
+; SM70-NEXT: bfe.u32 %r6, %r5, 16, 1;
+; SM70-NEXT: add.s32 %r7, %r6, %r5;
+; SM70-NEXT: add.s32 %r8, %r7, 32767;
+; SM70-NEXT: setp.nan.f32 %p1, %r5, %r5;
+; SM70-NEXT: or.b32 %r9, %r5, 4194304;
+; SM70-NEXT: selp.b32 %r10, %r9, %r8, %p1;
+; SM70-NEXT: cvt.u32.u16 %r11, %rs3;
+; SM70-NEXT: shl.b32 %r12, %r11, 16;
; SM70-NEXT: cvt.u32.u16 %r13, %rs1;
; SM70-NEXT: shl.b32 %r14, %r13, 16;
-; SM70-NEXT: cvt.u32.u16 %r15, %rs3;
-; SM70-NEXT: shl.b32 %r16, %r15, 16;
-; SM70-NEXT: add.rn.f32 %r17, %r16, %r14;
-; SM70-NEXT: bfe.u32 %r18, %r17, 16, 1;
-; SM70-NEXT: add.s32 %r19, %r18, %r17;
-; SM70-NEXT: add.s32 %r20, %r19, 32767;
-; SM70-NEXT: setp.nan.f32 %p2, %r17, %r17;
-; SM70-NEXT: or.b32 %r21, %r17, 4194304;
-; SM70-NEXT: selp.b32 %r22, %r21, %r20, %p2;
-; SM70-NEXT: prmt.b32 %r23, %r22, %r12, 0x7632U;
-; SM70-NEXT: st.param.b32 [func_retval0], %r23;
+; SM70-NEXT: add.rn.f32 %r15, %r14, %r12;
+; SM70-NEXT: bfe.u32 %r16, %r15, 16, 1;
+; SM70-NEXT: add.s32 %r17, %r16, %r15;
+; SM70-NEXT: add.s32 %r18, %r17, 32767;
+; SM70-NEXT: setp.nan.f32 %p2, %r15, %r15;
+; SM70-NEXT: or.b32 %r19, %r15, 4194304;
+; SM70-NEXT: selp.b32 %r20, %r19, %r18, %p2;
+; SM70-NEXT: prmt.b32 %r21, %r20, %r10, 0x7632U;
+; SM70-NEXT: st.param.b32 [func_retval0], %r21;
; SM70-NEXT: ret;
;
; SM80-LABEL: test_faddx2(
@@ -184,31 +182,29 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: .reg .b32 %r<5>;
; SM80-EMPTY:
; SM80-NEXT: // %bb.0:
-; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
-; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
+; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
+; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
; SM80-NEXT: mov.b32 %r3, 1065369472;
-; SM80-NEXT: fma.rn.bf16x2 %r4, %r2, %r3, %r1;
+; SM80-NEXT: fma.rn.bf16x2 %r4, %r1, %r3, %r2;
; SM80-NEXT: st.param.b32 [func_retval0], %r4;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_faddx2(
; SM80-FTZ: {
; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
-; SM80-FTZ-NEXT: .reg .b32 %r<10>;
+; SM80-FTZ-NEXT: .reg .b32 %r<8>;
; SM80-FTZ-EMPTY:
; SM80-FTZ-NEXT: // %bb.0:
-; SM80-FTZ-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
-; SM80-FTZ-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
-; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r3, %rs1;
-; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r4, %rs3;
-; SM80-FTZ-NEXT: add.rn.ftz.f32 %r5, %r4, %r3;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r6, %rs2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r7, %rs4;
-; SM80-FTZ-NEXT: add.rn.ftz.f32 %r8, %r7, %r6;
-; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r9, %r8, %r5;
-; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r9;
+; SM80-FTZ-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_faddx2_param_0];
+; SM80-FTZ-NEXT: ld.param.v2.b16 {%rs3, %rs4}, [test_faddx2_param_1];
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r1, %rs3;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r2, %rs1;
+; SM80-FTZ-NEXT: add.rn.ftz.f32 %r3, %r2, %r1;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r4, %rs4;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %r5, %rs2;
+; SM80-FTZ-NEXT: add.rn.ftz.f32 %r6, %r5, %r4;
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r7, %r6, %r3;
+; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r7;
; SM80-FTZ-NEXT: ret;
;
; SM90-LABEL: test_faddx2(
@@ -216,9 +212,9 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM90-NEXT: .reg .b32 %r<4>;
; SM90-EMPTY:
; SM90-NEXT: // %bb.0:
-; SM90-NEXT: ld.param.b32 %r1, [test_faddx2_param_1];
-; SM90-NEXT: ld.param.b32 %r2, [test_faddx2_param_0];
-; SM90-NEXT: add.rn.bf16x2 %r3, %r2, %r1;
+; SM90-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
+; SM90-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
+; SM90-NEXT: add.rn.bf16x2 %r3, %r1, %r2;
; SM90-NEXT: st.param.b32 [func_retval0], %r3;
; SM90-NEXT: ret;
%r = fadd <2 x bfloat> %a, %b
@@ -230,37 +226,35 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM70: {
; SM70-NEXT: .reg .pred %p<3>;
; SM70-NEXT: .reg .b16 %rs<5>;
-; SM70-NEXT: .reg .b32 %r<24>;
+; SM70-NEXT: .reg .b32 %r<22>;
; SM70-EMPTY:
; SM70-NEXT: // %bb.0:
-; SM70-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
-; SM70-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
-; SM70-NEXT: mov.b32 {%rs1, %rs2}, %r2;
+; SM70-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_fsubx2_param_0];
+; SM70-NEXT: ld.param.v2.b16 {%rs3, %rs4}, [test_fsubx2_param_1];
+; SM70-NEXT: cvt.u32.u16 %r1, %rs4;
+; SM70-NEXT: shl.b32 %r2, %r1, 16;
; SM70-NEXT: cvt.u32.u16 %r3, %rs2;
; SM70-NEXT: shl.b32 %r4, %r3, 16;
-; SM70-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM70-NEXT: cvt.u32.u16 %r5, %rs4;
-; SM70-NEXT: shl.b32 %r6, %r5, 16;
-; SM70-NEXT: sub.rn.f32 %r7, %r6, %r4;
-; SM70-NEXT: bfe.u32 %r8, %r7, 16, 1;
-; SM70-NEXT: a...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
Looks good in principle, but there are a few oddities in the test changes.
One thing I wonder about is how this changes the PTX semantics of a program. If we change from loading something as a b32 to a v2.b16, will this impact the memory consistency guarantees in PTX?
Force-pushed cb31619 to 8412432.
I'm not aware of anything in the spec that leads to different semantics. I would think that since the underlying data being accessed and the alignment requirements are the same in both cases, they are essentially the "same operation" using different numbers of registers.
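For concreteness, here is the shape of the change under discussion, condensed from the test updates above (register names adapted):

// before: load the packed param as one b32, then unpack with a mov
ld.param.b32 %r1, [test_faddx2_param_0];
mov.b32 {%rs1, %rs2}, %r1;

// after: load directly into two b16 registers
ld.param.v2.b16 {%rs1, %rs2}, [test_faddx2_param_0];

Both forms read the same four bytes at the same address, and both require the access to be aligned to the full 4-byte width, so the observable memory behavior should match.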
Force-pushed 8412432 to 14a5e84.
Force-pushed 1ccea45 to bb58a83.
Force-pushed bb58a83 to 98b7c87.
Nice, LGTM