Skip to content

Handle VECREDUCE intrinsics in NVPTX backend #136253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Prince781
Copy link
Contributor

Lower VECREDUCE intrinsics to tree reductions when reassociations are allowed. And add support for min3/max3 from PTX ISA 8.8 (to be released next week).

@llvmbot
Copy link
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

Changes

Lower VECREDUCE intrinsics to tree reductions when reassociations are allowed. And add support for min3/max3 from PTX ISA 8.8 (to be released next week).


Patch is 94.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136253.diff

6 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTX.td (+8-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+230)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+6)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+54)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+2)
  • (added) llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll (+1908)
diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td
index 5467ae011a208..d4dc278cfa648 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.td
+++ b/llvm/lib/Target/NVPTX/NVPTX.td
@@ -36,17 +36,19 @@ class FeaturePTX<int version>:
 
 foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
               60, 61, 62, 70, 72, 75, 80, 86, 87,
-              89, 90, 100, 101, 120] in
+              89, 90, 100, 101, 103, 120, 121] in
   def SM#sm: FeatureSM<""#sm, !mul(sm, 10)>;
 
 def SM90a: FeatureSM<"90a", 901>;
 def SM100a: FeatureSM<"100a", 1001>;
 def SM101a: FeatureSM<"101a", 1011>;
+def SM103a: FeatureSM<"103a", 1031>;
 def SM120a: FeatureSM<"120a", 1201>;
+def SM121a: FeatureSM<"121a", 1211>;
 
 foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
                    70, 71, 72, 73, 74, 75, 76, 77, 78,
-                   80, 81, 82, 83, 84, 85, 86, 87] in
+                   80, 81, 82, 83, 84, 85, 86, 87, 88] in
   def PTX#version: FeaturePTX<version>;
 
 //===----------------------------------------------------------------------===//
@@ -81,8 +83,12 @@ def : Proc<"sm_100", [SM100, PTX86]>;
 def : Proc<"sm_100a", [SM100a, PTX86]>;
 def : Proc<"sm_101", [SM101, PTX86]>;
 def : Proc<"sm_101a", [SM101a, PTX86]>;
+def : Proc<"sm_103", [SM103, PTX88]>;
+def : Proc<"sm_103a", [SM103a, PTX88]>;
 def : Proc<"sm_120", [SM120, PTX87]>;
 def : Proc<"sm_120a", [SM120a, PTX87]>;
+def : Proc<"sm_121", [SM121, PTX88]>;
+def : Proc<"sm_121a", [SM121a, PTX88]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9bde2a976e164..3a14aa47e5065 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -831,6 +831,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These are transformed into a tree evaluation
+  // of nodes which may initially be illegal.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1082,6 +1102,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2136,6 +2160,194 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// progressively grouping elements bottom-up.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+  bool IsReassociatable;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsReassociatable = true;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+
+  // scalarize vector
+  SmallVector<SDValue> Elements(NumElts);
+  for (unsigned I = 0, E = NumElts; I != E; ++I) {
+    Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
+                              DAG.getConstant(I, DL, MVT::i64));
+  }
+
+  // Lower to tree reduction.
+  if (IsReassociatable || Flags.hasAllowReassociation()) {
+    // we don't expect an accumulator for reassociatable vector reduction ops
+    assert(!Accumulator && "unexpected accumulator");
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+  }
+
+  // Lower to sequential reduction.
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2879,6 +3091,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index dd90746f6d9d6..0880f3e8edfe8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -296,6 +301,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 16b489afddf5c..721b578ba72ee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -368,6 +368,46 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
+// 3-input min/max (sm_100+) for f32 only
+multiclass FMINIMUMMAXIMUM3<string OpcStr, SDNode OpNode> {
+   def f32rrr_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rri_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rii_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rrr :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+   def f32rri :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b, fpimm:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+   def f32rii :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args.  The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1101,6 +1141,20 @@ defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
 defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
 
+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min.NaN", nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max.NaN", nvptx_fmaximum3>;
+
 defm FABS  : F2<"abs", fabs>;
 defm FNEG  : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 9e77f628da7a7..2cc81a064152f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -83,6 +83,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.
diff --git a/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
new file mode 100644
index 0000000000000..a9101ba3ca651
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
@@ -0,0 +1,1908 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --extra_scrub --version 5
+; RUN: llc < %s -mcpu=sm_80 -mattr=+ptx70 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | FileCheck -check-prefixes CHECK,CHECK-SM80 %s
+; RUN: %if ptxas-12.9 %{ llc < %s -mcpu=sm_80 -mattr=+ptx70 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | %ptxas-verify -arch=sm_80 %}
+; RUN: llc < %s -mcpu=sm_100 -mattr=+ptx88 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | FileCheck -check-prefixes CHECK,CHECK-SM100 %s
+; RUN: %if ptxas-12.9 %{ llc < %s -mcpu=sm_100 -mattr=+ptx88 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | %ptxas-verify -arch=sm_100 %}
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+; Check straight line reduction.
+define half @reduce_fadd_half(<8 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<18>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
+; CHECK-NEXT:    mov.b16 %rs3, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs4, %rs1, %rs3;
+; CHECK-NEXT:    add.rn.f16 %rs5, %rs4, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs6, %rs7}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs8, %rs5, %rs6;
+; CHECK-NEXT:    add.rn.f16 %rs9, %rs8, %rs7;
+; CHECK-NEXT:    mov.b32 {%rs10, %rs11}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs12, %rs9, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs12, %rs11;
+; CHECK-NEXT:    mov.b32 {%rs14, %rs15}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs16, %rs13, %rs14;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs16, %rs15;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
+; CHECK-NEXT:    ret;
+  %res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
+  ret half %res
+}
+
+; Check tree reduction.
+define half @reduce_fadd_half_reassoc(<8 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half_reassoc(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<18>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_reassoc_param_0];
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs4, %rs5}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs6, %rs4, %rs5;
+; CHECK-NEXT:    add.rn.f16 %rs7, %rs6, %rs3;
+; CHECK-NEXT:    mov.b32 {%rs8, %rs9}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs10, %rs8, %rs9;
+; CHECK-NEXT:    mov.b32 {%rs11, %rs12}, %r1;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs11, %rs12;
+; CHECK-NEXT:    add.rn.f16 %rs14, %rs13, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs15, %rs14, %rs7;
+; CHECK-NEXT:    mov.b16 %rs16, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs15, %rs16;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
+; CHECK-NEXT:    ret;
+  %res = call reassoc half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
+  ret half %res
+}
+
+; Check tree reduction with non-power of 2 size.
+define half @reduce_fadd_half_reassoc_nonpow2(<7 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half_reassoc_nonpow2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<16>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [reduce_fadd_half_reassoc_nonpow2_param_0+8];
+; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r1;
+; CHECK-NEXT:    ld.param.b16 %rs7, [reduce_fadd_half_reassoc_nonpow2_param_0+12];
+; CHECK-NEXT:    ld.param....
[truncated]

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch from 1f1348c to 4e481f8 Compare April 18, 2025 04:42
Comment on lines 446 to 451
; CHECK-SM100-NEXT: ld.param.v4.f32 {%f5, %f6, %f7, %f8}, [reduce_fmax_float_reassoc_param_0+16];
; CHECK-SM100-NEXT: ld.param.v4.f32 {%f1, %f2, %f3, %f4}, [reduce_fmax_float_reassoc_param_0];
; CHECK-SM100-NEXT: max.f32 %f9, %f4, %f5, %f6;
; CHECK-SM100-NEXT: max.f32 %f10, %f1, %f2, %f3;
; CHECK-SM100-NEXT: max.f32 %f11, %f10, %f9, %f7;
; CHECK-SM100-NEXT: max.f32 %f12, %f11, %f8;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good example of tree reduction being split across multiple loads that may arrive at different times:

max2(
  max3(
    max3(f4,f5,f6),
    max3(f1,f2,f3).
    f7),
  f8)

I wonder if there would be an observable performance difference vs:

max3(
  max3(f1,f2,f3),
  max3(f5,f6,f7)
  max2(f4, f8)
)

It would potentially have one instruction shorted data dependency chain.

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from 1ea7992 to c895233 Compare April 23, 2025 04:43
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch from c895233 to b13ed82 Compare May 22, 2025 11:10
@Prince781
Copy link
Contributor Author

After some thought, it seems all three patterns of lowering a reduction have their tradeoffs and each is preferred in some circumstance. I'd like to propose keeping the default lowering (shuffle reduction), but add a way for clients to override this with metadata. I'm thinking about something like !reduce.tree on the reduction call. Before I proceed, I'd like to know if this is a plausible idea and if I should propose it in an RFC. Thanks!

@Prince781
Copy link
Contributor Author

Prince781 commented Jun 11, 2025

Proposed use:

; default (shuffle reduction for f16x2 and f32x2)
define float @reduce_fadd_reassoc(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fadd(float 0.0, <16 x float> %in)
  ret float %res
}

; default (tree reduction on SM100 for fmin3; shuffle reduction otherwise)
define float @reduce_fmin_reassoc(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fmin(float 0.0, <16 x float> %in)
  ret float %res
}

; force tree for reassoc float
define float @reduce_fadd_reassoc(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fadd(float 0.0, <16 x float> %in) !reduce.tree
  ret float %res
}

; force shuffle
define float @reduce_fmin_reassoc(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fmin(float 0.0, <16 x float> %in) !reduce.shuffle
  ret float %res
}

; force sequential
define i32 @reduce_umin(<16 x i32> %in) {
  %res = call i32 @llvm.vector.reduce.umin(<16 x i32> %in) !reduce.sequential
  ret i32 %res
}

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from 24925dc to a22c141 Compare June 19, 2025 04:49
Copy link

github-actions bot commented Jun 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Prince781
Copy link
Contributor Author

Actually, I'm going to move the proposal of metadata / attributes affecting lowering decisions to another discussion in a future PR or RFC.

I've just updated the PR. We'll use the default shuffle reduction in SelectionDAG when packed ops are available on the target for the element type (ex: v2f16, v2bf16, v2i16). Otherwise we'll use tree or sequential reductions, depending on the fast-math option or reassoc flag, and whether there are special operations available like fmin3 / fmax3.

This allows us to keep using packed operations like max.f16x2 (which use less registers) while switching to tree reduction in every other case where we can reassociate operands. We're deciding this based on the target's support for the element type with the operation, so I think it makes sense to handle these intrinsics at the SelectionDAG level and not at the IR level inside shouldExpandReductions().

I also notice that because we now fallback to the shuffle reduction generated by SelectionDAG instead of ExpandReductions, the codegen is cleaner and the last packed operation is scalarized. So this PR may supersede #143943

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from 76acdea to ca244ab Compare June 19, 2025 05:11
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in
PTX 8.8.

This method of tree reduction has a few benefits over the default in
DAGTypeLegalizer:
- The default shuffle reduction progressively halves and partially
  reduces the vector down until we reach a single element. This produces
  a sequence of operations that combine disparate elements of the
  vector. For example, `vecreduce_fadd <4 x f32><a b c d>` will give `(a
  + c) + (b + d)`, whereas the tree reduction produces (a + b) + (c + d)
  by grouping nearby elements together first. Both use the same number
  of registers, but the shuffle reduction has longer live ranges. The
  same example is graphed below. Note we hold onto 3 registers for 2
  cycles in the shuffle reduction and 1 cycle in tree reduction.

  (shuffle reduction)

  PTX:
  %r1 = add.f32 a, c
  %r2 = add.f32 b, d
  %r3 = add.f32 %r1, %r3

  Pipeline:
  cycles ---->
  |      1       |      2       |         3          |         4          |          5         |          6             |
  | a = load.f32 | b = load.f32 | c = load.f32       | d = load.f32       |                    |                        |
  |              |              |                    | %r1 = add.f32 a, c | %r2 = add.f32 b, d | %r3 = add.f32 %r1, %r2 |

  live regs ---->
  | a       [1R] | a b     [2R] | a b c         [3R] | b d %r1       [3R] | %r1 %r2       [2R] | %r3               [1R] |

  (tree reduction)

  PTX:
  %r1 = add.f32 a, b
  %r2 = add.f32 c, d
  %r3 = add.f32 %r1, %r2

  Pipeline:
  cycles ---->
  |      1       |      2       |         3          |         4          |          5         |          6             |
  | a = load.f32 | b = load.f32 | c = load.f32       | d = load.f32       |                    |                        |
  |              |              | %r1 = add.f32 a, b |                    | %r2 = add.f32 c, d | %r3 = add.f32 %r1, %r2 |

  live regs ---->
  | a       [1R] | a b     [2R] | c %r1         [2R] | c %r1 d       [3R] | %r1 %r2       [2R] | %r3               [1R] |

- The shuffle reduction cannot easily support fmax3/fmin3 because it
  progressively halves the input vector.
- Faster compile time. Happens in one pass over the intrinsic, rather
  than O(N) passes if iteratively splitting the vector operands.
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch from ca244ab to 36e8acd Compare June 19, 2025 05:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants