Skip to content

Commit 2ff6c05

Browse files
committed
[NVPTX] add more comments to PerformLoadCombine
1 parent 40a2396 commit 2ff6c05

File tree

1 file changed

+104
-32
lines changed

1 file changed

+104
-32
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 104 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5332,6 +5332,55 @@ getMachineMemOperandForType(const SelectionDAG &DAG,
53325332
LLT(VT));
53335333
}
53345334

5335+
// These are Combiner rules for expanding v2f32 load results when they are
5336+
// really being used as their individual f32 components. Now that v2f32 is a
5337+
// legal type for a register, LowerFormalArguments() and ReplaceLoadVector()
5338+
// will pack two f32s into a single 64-bit register, leading to ld.b64 instead
5339+
// of ld.v2.f32, or ld.v2.b64 instead of ld.v4.f32. Sometimes this is ideal if
5340+
// the results stay packed because they're passed to another instruction that
5341+
// supports packed f32s (e.g. fmul.f32x2) or (rarely) if v2f32 really is being
5342+
// reinterpreted as an i64, and then stored.
5343+
//
5344+
// Otherwise, SelectionDAG will unpack the results with a sequence of bitcasts,
5345+
// extensions, and extracts if they go through any other kind of instruction.
5346+
// This is not ideal, so we undo these patterns and rewrite the load to output
5347+
// twice as many registers: two f32s for every one i64. This preserves PTX
5348+
// codegen for programs that don't use packed f32s.
5349+
//
5350+
// Also, LowerFormalArguments() and ReplaceLoadVector() happen too early for us
5351+
// to know whether the def-use chain for a particular load will eventually
5352+
// include instructions supporting packed f32s. That is why we prefer to resolve
5353+
// this problem within DAG Combiner.
5354+
//
5355+
// This rule proceeds in three general steps:
5356+
//
5357+
// 1. Identify the pattern, by traversing the def-use chain.
5358+
// 2. Rewrite the load, by splitting each 64-bit result into two f32 registers.
5359+
// 3. Rewrite all uses of the load, including chain and glue uses.
5360+
//
5361+
// This has the effect of combining multiple instructions into a single load.
5362+
// For example:
5363+
//
5364+
// (before, ex1)
5365+
// v: v2f32 = LoadParam [p]
5366+
// f1: f32 = extractelt v, 0
5367+
// f2: f32 = extractelt v, 1
5368+
// r = add.f32 f1, f2
5369+
//
5370+
// ...or...
5371+
//
5372+
// (before, ex2)
5373+
// i: i64 = LoadParam [p]
5374+
// v: v2f32 = bitcast i
5375+
// f1: f32 = extractelt v, 0
5376+
// f2: f32 = extractelt v, 1
5377+
// r = add.f32 f1, f2
5378+
//
5379+
// ...will become...
5380+
//
5381+
// (after for both)
5382+
// vf: f32,f32 = LoadParamV2 [p]
5383+
// r = add.f32 vf:0, vf:1
53355384
static SDValue PerformLoadCombine(SDNode *N,
53365385
TargetLowering::DAGCombinerInfo &DCI,
53375386
const NVPTXSubtarget &STI) {
@@ -5351,6 +5400,7 @@ static SDValue PerformLoadCombine(SDNode *N,
53515400
return VT == MVT::i64 || VT == MVT::f32 || VT.isVector();
53525401
});
53535402

5403+
// (1) All we are doing here is looking for patterns.
53545404
SmallDenseMap<SDNode *, unsigned> ExtractElts;
53555405
SmallVector<SDNode *> ProxyRegs(OrigNumResults, nullptr);
53565406
SmallVector<std::pair<SDNode *, unsigned>> WorkList{{N, {}}};
@@ -5402,24 +5452,18 @@ static SDValue PerformLoadCombine(SDNode *N,
54025452
ProcessingInitialLoad = false;
54035453
}
54045454

5405-
// (2) If the load's value is only used as f32 elements, replace all
5406-
// extractelts with individual elements of the newly-created load. If there's
5407-
// a ProxyReg, handle that too. After this check, we'll proceed in the
5408-
// following way:
5409-
// 1. Determine which type of load to create, which will split the results
5410-
// of the original load into f32 components.
5411-
// 2. If there's a ProxyReg, split that too.
5412-
// 3. Replace all extractelts with references to the new load / proxy reg.
5413-
// 4. Replace all glue/chain references with references to the new load /
5414-
// proxy reg.
5455+
// Did we find any patterns? All patterns we're interested in end with an
5456+
// extractelt.
54155457
if (ExtractElts.empty())
54165458
return SDValue();
54175459

5460+
// (2) Now, we will decide what load to create.
5461+
54185462
// Do we have to tweak the opcode for an NVPTXISD::Load* or do we have to
54195463
// rewrite an ISD::LOAD?
54205464
std::optional<NVPTXISD::NodeType> NewOpcode;
54215465

5422-
// LoadV's are handled slightly different in ISelDAGToDAG.
5466+
// LoadV's are handled slightly differently in ISelDAGToDAG. See below.
54235467
bool IsLoadV = false;
54245468
switch (N->getOpcode()) {
54255469
case NVPTXISD::LoadV2:
@@ -5434,7 +5478,15 @@ static SDValue PerformLoadCombine(SDNode *N,
54345478
break;
54355479
}
54365480

5437-
SDValue OldChain, OldGlue;
5481+
// We haven't created the new load yet, but we're saving some information
5482+
// about the old load because we will need to replace all uses of it later.
5483+
// Because our pattern is generic, we're matching ISD::LOAD and
5484+
// NVPTXISD::Load*, and we just search for the chain and glue outputs rather
5485+
// than have a case for each type of load.
5486+
const bool HaveProxyRegs =
5487+
llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; });
5488+
5489+
SDValue OldChain, OldGlue /* optional */;
54385490
for (unsigned I = 0, E = N->getNumValues(); I != E; ++I) {
54395491
if (N->getValueType(I) == MVT::Other)
54405492
OldChain = SDValue(N, I);
@@ -5444,7 +5496,8 @@ static SDValue PerformLoadCombine(SDNode *N,
54445496

54455497
SDValue NewLoad, NewChain, NewGlue /* (optional) */;
54465498
unsigned NumElts = 0;
5447-
if (NewOpcode) { // tweak NVPTXISD::Load* opcode
5499+
if (NewOpcode) {
5500+
// Here, we are tweaking an NVPTXISD::Load* opcode to output N*2 results.
54485501
SmallVector<EVT> VTs;
54495502

54505503
// should always be non-null after this
@@ -5485,6 +5538,15 @@ static SDValue PerformLoadCombine(SDNode *N,
54855538
if (NewGlueIdx)
54865539
NewGlue = NewLoad.getValue(*NewGlueIdx);
54875540
} else if (N->getOpcode() == ISD::LOAD) { // rewrite a load
5541+
// Here, we are lowering an ISD::LOAD to an NVPTXISD::Load*. For example:
5542+
//
5543+
// (before)
5544+
// v2f32,ch,glue = ISD::LOAD [p]
5545+
//
5546+
// ...becomes...
5547+
//
5548+
// (after)
5549+
// f32,f32,ch,glue = NVPTXISD::LoadV2 [p]
54885550
std::optional<EVT> CastToType;
54895551
EVT ResVT = N->getValueType(0);
54905552
if (ResVT == MVT::i64) {
@@ -5502,23 +5564,41 @@ static SDValue PerformLoadCombine(SDNode *N,
55025564
}
55035565
}
55045566

5567+
// If this was some other type of load we couldn't handle, we bail.
55055568
if (!NewLoad)
5506-
return SDValue(); // could not match pattern
5569+
return SDValue();
55075570

5508-
// (3) begin rewriting uses
5571+
// (3) We successfully rewrote the load. Now we must rewrite all uses of the
5572+
// old load.
55095573
SmallVector<SDValue> NewOutputsF32;
55105574

5511-
if (llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; })) {
5512-
// scalarize proxy regs, but first rewrite all uses of chain and glue from
5513-
// the old load to the new load
5575+
if (!HaveProxyRegs) {
5576+
// The case without proxy registers in the def-use chain is simple. Each
5577+
// extractelt is matched to an output of the new load (see calls to
5578+
// DCI.CombineTo() below).
5579+
for (unsigned I = 0, E = NumElts; I != E; ++I)
5580+
if (NewLoad->getValueType(I) == MVT::f32)
5581+
NewOutputsF32.push_back(NewLoad.getValue(I));
5582+
5583+
// Rewrite all uses of the old load's chain and (if present) glue outputs.
5584+
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
5585+
if (OldGlue)
5586+
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
5587+
} else {
5588+
// The case with proxy registers is slightly more complicated. We have to
5589+
// expand those too.
5590+
5591+
// First, rewrite all uses of chain and glue from the old load to the new
5592+
// load. This is one less thing to worry about.
55145593
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
55155594
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
55165595

5596+
// Now we will expand all the proxy registers for each output.
55175597
for (unsigned ProxyI = 0, ProxyE = ProxyRegs.size(); ProxyI != ProxyE;
55185598
++ProxyI) {
55195599
SDNode *ProxyReg = ProxyRegs[ProxyI];
55205600

5521-
// no proxy reg might mean this result is unused
5601+
// No proxy reg might mean this result is unused.
55225602
if (!ProxyReg)
55235603
continue;
55245604

@@ -5532,12 +5612,12 @@ static SDValue PerformLoadCombine(SDNode *N,
55325612
if (SDValue OldInGlue = ProxyReg->getOperand(2); OldInGlue.getNode() != N)
55335613
NewGlue = OldInGlue;
55345614

5535-
// update OldChain, OldGlue to the outputs of ProxyReg, which we will
5536-
// replace later
5615+
// Update OldChain, OldGlue to the outputs of ProxyReg, which we will
5616+
// replace later.
55375617
OldChain = SDValue(ProxyReg, 1);
55385618
OldGlue = SDValue(ProxyReg, 2);
55395619

5540-
// generate the scalar proxy regs
5620+
// Generate the scalar proxy regs.
55415621
for (unsigned I = 0, E = 2; I != E; ++I) {
55425622
SDValue ProxyRegElem = DCI.DAG.getNode(
55435623
NVPTXISD::ProxyReg, SDLoc(ProxyReg),
@@ -5552,18 +5632,10 @@ static SDValue PerformLoadCombine(SDNode *N,
55525632
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
55535633
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
55545634
}
5555-
} else {
5556-
for (unsigned I = 0, E = NumElts; I != E; ++I)
5557-
if (NewLoad->getValueType(I) == MVT::f32)
5558-
NewOutputsF32.push_back(NewLoad.getValue(I));
5559-
5560-
// replace all glue and chain nodes
5561-
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
5562-
if (OldGlue)
5563-
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
55645635
}
55655636

5566-
// replace all extractelts with the new outputs
5637+
// Replace all extractelts with the new outputs. This leaves the old load and
5638+
// unpacking instructions dead.
55675639
for (auto &[Extract, Index] : ExtractElts)
55685640
DCI.CombineTo(Extract, NewOutputsF32[Index], false);
55695641

0 commit comments

Comments
 (0)