@@ -5332,6 +5332,55 @@ getMachineMemOperandForType(const SelectionDAG &DAG,
5332
5332
LLT (VT));
5333
5333
}
5334
5334
5335
+ // These are Combiner rules for expanding v2f32 load results when they are
5336
+ // really being used as their individual f32 components. Now that v2f32 is a
5337
+ // legal type for a register, LowerFormalArguments() and ReplaceLoadVector()
5338
+ // will pack two f32s into a single 64-bit register, leading to ld.b64 instead
5339
+ // of ld.v2.f32 or ld.v2.b64 instead of ld.v4.f32. Sometimes this is ideal if
5340
+ // the results stay packed because they're passed to another instruction that
5341
+ // supports packed f32s (e.g. fmul.f32x2) or (rarely) if v2f32 really is being
5342
+ // reinterpreted as an i64, and then stored.
5343
+ //
5344
+ // Otherwise, SelectionDAG will unpack the results with a sequence of bitcasts,
5345
+ // extensions, and extracts if they go through any other kind of instruction.
5346
+ // This is not ideal, so we undo these patterns and rewrite the load to output
5347
+ // twice as many registers: two f32s for every one i64. This preserves PTX
5348
+ // codegen for programs that don't use packed f32s.
5349
+ //
5350
+ // Also, LowerFormalArguments() and ReplaceLoadVector() happen too early for us
5351
+ // to know whether the def-use chain for a particular load will eventually
5352
+ // include instructions supporting packed f32s. That is why we prefer to resolve
5353
+ // this problem within the DAG Combiner.
5354
+ //
5355
+ // This rule proceeds in three general steps:
5356
+ //
5357
+ // 1. Identify the pattern, by traversing the def-use chain.
5358
+ // 2. Rewrite the load, by splitting each 64-bit result into two f32 registers.
5359
+ // 3. Rewrite all uses of the load, including chain and glue uses.
5360
+ //
5361
+ // This has the effect of combining multiple instructions into a single load.
5362
+ // For example:
5363
+ //
5364
+ // (before, ex1)
5365
+ // v: v2f32 = LoadParam [p]
5366
+ // f1: f32 = extractelt v, 0
5367
+ // f2: f32 = extractelt v, 1
5368
+ // r = add.f32 f1, f2
5369
+ //
5370
+ // ...or...
5371
+ //
5372
+ // (before, ex2)
5373
+ // i: i64 = LoadParam [p]
5374
+ // v: v2f32 = bitcast i
5375
+ // f1: f32 = extractelt v, 0
5376
+ // f2: f32 = extractelt v, 1
5377
+ // r = add.f32 f1, f2
5378
+ //
5379
+ // ...will become...
5380
+ //
5381
+ // (after for both)
5382
+ // vf: f32,f32 = LoadParamV2 [p]
5383
+ // r = add.f32 vf:0, vf:1
5335
5384
static SDValue PerformLoadCombine (SDNode *N,
5336
5385
TargetLowering::DAGCombinerInfo &DCI,
5337
5386
const NVPTXSubtarget &STI) {
@@ -5351,6 +5400,7 @@ static SDValue PerformLoadCombine(SDNode *N,
5351
5400
return VT == MVT::i64 || VT == MVT::f32 || VT.isVector ();
5352
5401
});
5353
5402
5403
+ // (1) All we are doing here is looking for patterns.
5354
5404
SmallDenseMap<SDNode *, unsigned > ExtractElts;
5355
5405
SmallVector<SDNode *> ProxyRegs (OrigNumResults, nullptr );
5356
5406
SmallVector<std::pair<SDNode *, unsigned >> WorkList{{N, {}}};
@@ -5402,24 +5452,18 @@ static SDValue PerformLoadCombine(SDNode *N,
5402
5452
ProcessingInitialLoad = false ;
5403
5453
}
5404
5454
5405
- // (2) If the load's value is only used as f32 elements, replace all
5406
- // extractelts with individual elements of the newly-created load. If there's
5407
- // a ProxyReg, handle that too. After this check, we'll proceed in the
5408
- // following way:
5409
- // 1. Determine which type of load to create, which will split the results
5410
- // of the original load into f32 components.
5411
- // 2. If there's a ProxyReg, split that too.
5412
- // 3. Replace all extractelts with references to the new load / proxy reg.
5413
- // 4. Replace all glue/chain references with references to the new load /
5414
- // proxy reg.
5455
+ // Did we find any patterns? All patterns we're interested in end with an
5456
+ // extractelt.
5415
5457
if (ExtractElts.empty ())
5416
5458
return SDValue ();
5417
5459
5460
+ // (2) Now, we will decide what load to create.
5461
+
5418
5462
// Do we have to tweak the opcode for an NVPTXISD::Load* or do we have to
5419
5463
// rewrite an ISD::LOAD?
5420
5464
std::optional<NVPTXISD::NodeType> NewOpcode;
5421
5465
5422
- // LoadV's are handled slightly different in ISelDAGToDAG.
5466
+ // LoadV's are handled slightly differently in ISelDAGToDAG. See below.
5423
5467
bool IsLoadV = false ;
5424
5468
switch (N->getOpcode ()) {
5425
5469
case NVPTXISD::LoadV2:
@@ -5434,7 +5478,15 @@ static SDValue PerformLoadCombine(SDNode *N,
5434
5478
break ;
5435
5479
}
5436
5480
5437
- SDValue OldChain, OldGlue;
5481
+ // We haven't created the new load yet, but we're saving some information
5482
+ // about the old load because we will need to replace all uses of it later.
5483
+ // Because our pattern is generic, we're matching ISD::LOAD and
5484
+ // NVPTXISD::Load*, and we just search for the chain and glue outputs rather
5485
+ // than have a case for each type of load.
5486
+ const bool HaveProxyRegs =
5487
+ llvm::any_of (ProxyRegs, [](const SDNode *PR) { return PR != nullptr ; });
5488
+
5489
+ SDValue OldChain, OldGlue /* optional */ ;
5438
5490
for (unsigned I = 0 , E = N->getNumValues (); I != E; ++I) {
5439
5491
if (N->getValueType (I) == MVT::Other)
5440
5492
OldChain = SDValue (N, I);
@@ -5444,7 +5496,8 @@ static SDValue PerformLoadCombine(SDNode *N,
5444
5496
5445
5497
SDValue NewLoad, NewChain, NewGlue /* (optional) */ ;
5446
5498
unsigned NumElts = 0 ;
5447
- if (NewOpcode) { // tweak NVPTXISD::Load* opcode
5499
+ if (NewOpcode) {
5500
+ // Here, we are tweaking a NVPTXISD::Load* opcode to output N*2 results.
5448
5501
SmallVector<EVT> VTs;
5449
5502
5450
5503
// should always be non-null after this
@@ -5485,6 +5538,15 @@ static SDValue PerformLoadCombine(SDNode *N,
5485
5538
if (NewGlueIdx)
5486
5539
NewGlue = NewLoad.getValue (*NewGlueIdx);
5487
5540
} else if (N->getOpcode () == ISD::LOAD) { // rewrite a load
5541
+ // Here, we are lowering an ISD::LOAD to an NVPTXISD::Load*. For example:
5542
+ //
5543
+ // (before)
5544
+ // v2f32,ch,glue = ISD::LOAD [p]
5545
+ //
5546
+ // ...becomes...
5547
+ //
5548
+ // (after)
5549
+ // f32,f32,ch,glue = NVPTXISD::LoadV2 [p]
5488
5550
std::optional<EVT> CastToType;
5489
5551
EVT ResVT = N->getValueType (0 );
5490
5552
if (ResVT == MVT::i64 ) {
@@ -5502,23 +5564,41 @@ static SDValue PerformLoadCombine(SDNode *N,
5502
5564
}
5503
5565
}
5504
5566
5567
+ // If this was some other type of load we couldn't handle, we bail.
5505
5568
if (!NewLoad)
5506
- return SDValue (); // could not match pattern
5569
+ return SDValue ();
5507
5570
5508
- // (3) begin rewriting uses
5571
+ // (3) We successfully rewrote the load. Now we must rewrite all uses of the
5572
+ // old load.
5509
5573
SmallVector<SDValue> NewOutputsF32;
5510
5574
5511
- if (llvm::any_of (ProxyRegs, [](const SDNode *PR) { return PR != nullptr ; })) {
5512
- // scalarize proxy regs, but first rewrite all uses of chain and glue from
5513
- // the old load to the new load
5575
+ if (!HaveProxyRegs) {
5576
+ // The case without proxy registers in the def-use chain is simple. Each
5577
+ // extractelt is matched to an output of the new load (see calls to
5578
+ // DCI.CombineTo() below).
5579
+ for (unsigned I = 0 , E = NumElts; I != E; ++I)
5580
+ if (NewLoad->getValueType (I) == MVT::f32 )
5581
+ NewOutputsF32.push_back (NewLoad.getValue (I));
5582
+
5583
+ // replace all glue and chain nodes
5584
+ DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5585
+ if (OldGlue)
5586
+ DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5587
+ } else {
5588
+ // The case with proxy registers is slightly more complicated. We have to
5589
+ // expand those too.
5590
+
5591
+ // First, rewrite all uses of chain and glue from the old load to the new
5592
+ // load. This is one less thing to worry about.
5514
5593
DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5515
5594
DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5516
5595
5596
+ // Now we will expand all the proxy registers for each output.
5517
5597
for (unsigned ProxyI = 0 , ProxyE = ProxyRegs.size (); ProxyI != ProxyE;
5518
5598
++ProxyI) {
5519
5599
SDNode *ProxyReg = ProxyRegs[ProxyI];
5520
5600
5521
- // no proxy reg might mean this result is unused
5601
+ // No proxy reg might mean this result is unused.
5522
5602
if (!ProxyReg)
5523
5603
continue ;
5524
5604
@@ -5532,12 +5612,12 @@ static SDValue PerformLoadCombine(SDNode *N,
5532
5612
if (SDValue OldInGlue = ProxyReg->getOperand (2 ); OldInGlue.getNode () != N)
5533
5613
NewGlue = OldInGlue;
5534
5614
5535
- // update OldChain, OldGlue to the outputs of ProxyReg, which we will
5536
- // replace later
5615
+ // Update OldChain, OldGlue to the outputs of ProxyReg, which we will
5616
+ // replace later.
5537
5617
OldChain = SDValue (ProxyReg, 1 );
5538
5618
OldGlue = SDValue (ProxyReg, 2 );
5539
5619
5540
- // generate the scalar proxy regs
5620
+ // Generate the scalar proxy regs.
5541
5621
for (unsigned I = 0 , E = 2 ; I != E; ++I) {
5542
5622
SDValue ProxyRegElem = DCI.DAG .getNode (
5543
5623
NVPTXISD::ProxyReg, SDLoc (ProxyReg),
@@ -5552,18 +5632,10 @@ static SDValue PerformLoadCombine(SDNode *N,
5552
5632
DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5553
5633
DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5554
5634
}
5555
- } else {
5556
- for (unsigned I = 0 , E = NumElts; I != E; ++I)
5557
- if (NewLoad->getValueType (I) == MVT::f32 )
5558
- NewOutputsF32.push_back (NewLoad.getValue (I));
5559
-
5560
- // replace all glue and chain nodes
5561
- DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5562
- if (OldGlue)
5563
- DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5564
5635
}
5565
5636
5566
- // replace all extractelts with the new outputs
5637
+ // Replace all extractelts with the new outputs. This leaves the old load and
5638
+ // unpacking instructions dead.
5567
5639
for (auto &[Extract, Index] : ExtractElts)
5568
5640
DCI.CombineTo (Extract, NewOutputsF32[Index], false );
5569
5641
0 commit comments