Skip to content

Commit d06b3e3

Browse files
authored
[NVPTX] improve lowering for common byte-extraction operations. (#66945)
Some critical code paths we have depend on efficient byte extraction from data loaded as integers. By default LLVM tries to extract bytes by storing/loading from stack, which is very inefficient on GPU.
1 parent 2ab31b6 commit d06b3e3

File tree

3 files changed

+166
-11
lines changed

3 files changed

+166
-11
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llvm/ADT/SmallVector.h"
2424
#include "llvm/ADT/StringRef.h"
2525
#include "llvm/CodeGen/Analysis.h"
26+
#include "llvm/CodeGen/ISDOpcodes.h"
2627
#include "llvm/CodeGen/MachineFunction.h"
2728
#include "llvm/CodeGen/MachineMemOperand.h"
2829
#include "llvm/CodeGen/MachineValueType.h"
@@ -672,7 +673,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
672673

673674
// We have some custom DAG combine patterns for these nodes
674675
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
675-
ISD::SREM, ISD::UREM});
676+
ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
676677

677678
// setcc for f16x2 and bf16x2 needs special handling to prevent
678679
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5252,6 +5253,47 @@ static SDValue PerformSETCCCombine(SDNode *N,
52525253
CCNode.getValue(1));
52535254
}
52545255

5256+
/// Target DAG combine for ISD::EXTRACT_VECTOR_ELT.
///
/// Rewrites a constant, non-zero index extract from a 16-, 32- or 64-bit
/// vector into
///   trunc(sra(bitcast<iN>(Vector), Index * EltBits))
/// so that the element is obtained with in-register integer operations.
///
/// \param N   The EXTRACT_VECTOR_ELT node; operand 0 is the vector and
///            operand 1 is the element index.
/// \param DCI Combiner state; only DCI.DAG is used, to build the new nodes.
/// \returns The replacement value, typed as N's result type, or an empty
///          SDValue when the node is left untouched.
static SDValue PerformEXTRACTCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
  SelectionDAG &DAG = DCI.DAG;
  SDValue Vector = N->getOperand(0);
  const EVT VectorVT = Vector.getValueType();

  // Native vector loads already combine nicely w/ extract_vector_elt.
  if (Vector->getOpcode() == ISD::LOAD && VectorVT.isSimple() &&
      IsPTXVectorType(VectorVT.getSimpleVT()))
    return SDValue();

  // Don't mess with singletons or v2*16 types, we already handle them OK.
  if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT))
    return SDValue();

  // We only handle the types we can extract in-register.
  const uint64_t VectorBits = VectorVT.getSizeInBits();
  if (VectorBits != 16 && VectorBits != 32 && VectorBits != 64)
    return SDValue();

  // A variable index cannot be turned into a constant shift, and
  // Index == 0 is handled by generic DAG combiner.
  const auto *Index = dyn_cast<ConstantSDNode>(N->getOperand(1));
  if (!Index || Index->getZExtValue() == 0)
    return SDValue();

  SDLoc DL(N);

  const MVT IVT = MVT::getIntegerVT(VectorBits);
  const EVT EltVT = VectorVT.getVectorElementType();
  const EVT EltIVT = EltVT.changeTypeToInteger();
  const uint64_t EltBits = EltVT.getScalarSizeInBits();

  // View the whole vector as one integer, move the requested element down to
  // bit 0 and drop everything above it.
  SDValue AsInt = DAG.getNode(ISD::BITCAST, DL, IVT, Vector);
  SDValue ShiftAmt = DAG.getConstant(Index->getZExtValue() * EltBits, DL, IVT);
  SDValue Shifted = DAG.getNode(ISD::SRA, DL, IVT, AsInt, ShiftAmt);
  SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, EltIVT, Shifted);

  // If element has non-integer type, bitcast it back to the expected type.
  if (EltVT != EltIVT)
    Result = DAG.getNode(ISD::BITCAST, DL, EltVT, Result);

  // EXTRACT_VECTOR_ELT may produce an integer wider than the vector element
  // (the extra high bits are undefined). The replacement must have exactly
  // the type of the node it replaces, so widen it when the two differ.
  const EVT ResultVT = N->getValueType(0);
  if (EltVT != ResultVT)
    Result = DAG.getNode(ISD::ANY_EXTEND, DL, ResultVT, Result);

  return Result;
}
5296+
52555297
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
52565298
DAGCombinerInfo &DCI) const {
52575299
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5275,6 +5317,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
52755317
case NVPTXISD::StoreRetvalV2:
52765318
case NVPTXISD::StoreRetvalV4:
52775319
return PerformStoreRetvalCombine(N);
5320+
case ISD::EXTRACT_VECTOR_ELT:
5321+
return PerformEXTRACTCombine(N, DCI);
52785322
}
52795323
return SDValue();
52805324
}

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1713,34 +1713,56 @@ def FUNSHFRCLAMP :
17131713
// BFE - bit-field extract
17141714
//
17151715

1716-
// Template for BFE instructions. Takes four args,
1717-
// [dest (reg), src (reg), start (reg or imm), end (reg or imm)].
1716+
// Template for BFE/BFI instructions.
1717+
// Args: [dest (reg), src (reg), start (reg or imm), end (reg or imm)].
17181718
// Start may be an imm only if end is also an imm. FIXME: Is this a
17191719
// restriction in PTX?
17201720
//
17211721
// dest and src may be int32 or int64, but start and end are always int32.
1722-
multiclass BFE<string TyStr, RegisterClass RC> {
1722+
multiclass BFX<string Instr, RegisterClass RC> {
17231723
def rrr
17241724
: NVPTXInst<(outs RC:$d),
17251725
(ins RC:$a, Int32Regs:$b, Int32Regs:$c),
1726-
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
1726+
!strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
17271727
def rri
17281728
: NVPTXInst<(outs RC:$d),
17291729
(ins RC:$a, Int32Regs:$b, i32imm:$c),
1730-
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
1730+
!strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
17311731
def rii
17321732
: NVPTXInst<(outs RC:$d),
17331733
(ins RC:$a, i32imm:$b, i32imm:$c),
1734-
!strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
1734+
!strconcat(Instr, " \t$d, $a, $b, $c;"), []>;
17351735
}
17361736

17371737
let hasSideEffects = false in {
1738-
defm BFE_S32 : BFE<"s32", Int32Regs>;
1739-
defm BFE_U32 : BFE<"u32", Int32Regs>;
1740-
defm BFE_S64 : BFE<"s64", Int64Regs>;
1741-
defm BFE_U64 : BFE<"u64", Int64Regs>;
1738+
defm BFE_S32 : BFX<"bfe.s32", Int32Regs>;
1739+
defm BFE_U32 : BFX<"bfe.u32", Int32Regs>;
1740+
defm BFE_S64 : BFX<"bfe.s64", Int64Regs>;
1741+
defm BFE_U64 : BFX<"bfe.u64", Int64Regs>;
1742+
1743+
defm BFI_S32 : BFX<"bfi.s32", Int32Regs>;
1744+
defm BFI_U32 : BFX<"bfi.u32", Int32Regs>;
1745+
defm BFI_S64 : BFX<"bfi.s64", Int64Regs>;
1746+
defm BFI_U64 : BFX<"bfi.u64", Int64Regs>;
17421747
}
17431748

1749+
// Common byte extraction patterns
1750+
// Each pattern below maps a DAG shape that yields a sign-extended byte of a
// 32- or 64-bit register onto CVT_s8_* and/or a BFE_S* with a field length
// of 8.
// NOTE(review): per the PTX ISA, bfe.sNN takes its sign bit from
// min(pos + len - 1, msb), whereas srl shifts in zeros. The two agree only
// while $o + 8 <= register width -- confirm larger offsets never reach the
// srl-based patterns.

// i32 source, byte 0: truncate and sign-extend from i8, result is i16.
def : Pat<(i16 (sext_inreg (trunc Int32Regs:$s), i8)),
1751+
(CVT_s8_s32 Int32Regs:$s, CvtNONE)>;
1752+
// i32 source, byte starting at immediate bit offset $o, result is i16.
def : Pat<(i16 (sext_inreg (trunc (srl (i32 Int32Regs:$s), (i32 imm:$o))), i8)),
1753+
(CVT_s8_s32 (BFE_S32rii Int32Regs:$s, imm:$o, 8), CvtNONE)>;
1754+
// Same as above without the truncate: only the BFE is emitted.
def : Pat<(sext_inreg (srl (i32 Int32Regs:$s), (i32 imm:$o)), i8),
1755+
(BFE_S32rii Int32Regs:$s, imm:$o, 8)>;
1756+
// i32 source truncated to i16, then arithmetic shift right by 8: selected as
// a BFE at offset 8 with length 8.
def : Pat<(i16 (sra (i16 (trunc Int32Regs:$s)), (i32 8))),
1757+
(CVT_s8_s32 (BFE_S32rii Int32Regs:$s, 8, 8), CvtNONE)>;
1758+
1759+
// i64 source, byte starting at immediate bit offset $o: only the BFE is
// emitted.
def : Pat<(sext_inreg (srl (i64 Int64Regs:$s), (i32 imm:$o)), i8),
1760+
(BFE_S64rii Int64Regs:$s, imm:$o, 8)>;
1761+
// i64 source, byte 0: truncate and sign-extend from i8, result is i16.
def : Pat<(i16 (sext_inreg (trunc Int64Regs:$s), i8)),
1762+
(CVT_s8_s64 Int64Regs:$s, CvtNONE)>;
1763+
// i64 source, byte starting at immediate bit offset $o, result is i16.
def : Pat<(i16 (sext_inreg (trunc (srl (i64 Int64Regs:$s), (i32 imm:$o))), i8)),
1764+
(CVT_s8_s64 (BFE_S64rii Int64Regs:$s, imm:$o, 8), CvtNONE)>;
1765+
17441766
//-----------------------------------
17451767
// Comparison instructions (setp, set)
17461768
//-----------------------------------
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_35 -verify-machineinstrs | FileCheck %s
2+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_35 | %ptxas-verify %}
3+
4+
5+
; CHECK-LABEL: test_v2i8
6+
; CHECK-DAG: ld.param.u16 [[A:%rs[0-9]+]], [test_v2i8_param_0];
7+
; CHECK-DAG: cvt.s16.s8 [[E0:%rs[0-9]+]], [[A]];
8+
; CHECK-DAG: shr.s16 [[E1:%rs[0-9]+]], [[A]], 8;
9+
; Reinterprets the i16 argument as <2 x i8>, extracts both elements with
; constant indices, sign-extends each to i16 and returns their sum.
define i16 @test_v2i8(i16 %a) {
10+
%v = bitcast i16 %a to <2 x i8>
11+
%r0 = extractelement <2 x i8> %v, i64 0
12+
%r1 = extractelement <2 x i8> %v, i64 1
13+
%r0i = sext i8 %r0 to i16
14+
%r1i = sext i8 %r1 to i16
15+
%r01 = add i16 %r0i, %r1i
16+
ret i16 %r01
17+
}
18+
19+
; CHECK-LABEL: test_v4i8
20+
; CHECK: ld.param.u32 [[R:%r[0-9]+]], [test_v4i8_param_0];
21+
; CHECK-DAG: cvt.s8.s32 [[E0:%rs[0-9]+]], [[R]];
22+
; CHECK-DAG: bfe.s32 [[R1:%r[0-9]+]], [[R]], 8, 8;
23+
; CHECK-DAG: cvt.s8.s32 [[E1:%rs[0-9]+]], [[R1]];
24+
; CHECK-DAG: bfe.s32 [[R2:%r[0-9]+]], [[R]], 16, 8;
25+
; CHECK-DAG: cvt.s8.s32 [[E2:%rs[0-9]+]], [[R2]];
26+
; CHECK-DAG: bfe.s32 [[R3:%r[0-9]+]], [[R]], 24, 8;
27+
; CHECK-DAG: cvt.s8.s32 [[E3:%rs[0-9]+]], [[R3]];
28+
; Reinterprets the i32 argument as <4 x i8>, extracts all four elements with
; constant indices, sign-extends each to i16 and returns their sum.
define i16 @test_v4i8(i32 %a) {
29+
%v = bitcast i32 %a to <4 x i8>
30+
%r0 = extractelement <4 x i8> %v, i64 0
31+
%r1 = extractelement <4 x i8> %v, i64 1
32+
%r2 = extractelement <4 x i8> %v, i64 2
33+
%r3 = extractelement <4 x i8> %v, i64 3
34+
%r0i = sext i8 %r0 to i16
35+
%r1i = sext i8 %r1 to i16
36+
%r2i = sext i8 %r2 to i16
37+
%r3i = sext i8 %r3 to i16
38+
%r01 = add i16 %r0i, %r1i
39+
%r23 = add i16 %r2i, %r3i
40+
%r = add i16 %r01, %r23
41+
ret i16 %r
42+
}
43+
44+
; CHECK-LABEL: test_v8i8
45+
; CHECK: ld.param.u64 [[R:%rd[0-9]+]], [test_v8i8_param_0];
46+
; CHECK-DAG: cvt.s8.s64 [[E0:%rs[0-9]+]], [[R]];
47+
; Element 1 is still extracted by trunc, shr 8, not sure why.
48+
; CHECK-DAG: cvt.u16.u64 [[R01:%rs[0-9]+]], [[R]];
49+
; CHECK-DAG: shr.s16 [[E1:%rs[0-9]+]], [[R01]], 8;
50+
; CHECK-DAG: bfe.s64 [[RD2:%rd[0-9]+]], [[R]], 16, 8;
51+
; CHECK-DAG: cvt.s8.s64 [[E2:%rs[0-9]+]], [[RD2]];
52+
; CHECK-DAG: bfe.s64 [[RD3:%rd[0-9]+]], [[R]], 24, 8;
53+
; CHECK-DAG: cvt.s8.s64 [[E3:%rs[0-9]+]], [[RD3]];
54+
; CHECK-DAG: bfe.s64 [[RD4:%rd[0-9]+]], [[R]], 32, 8;
55+
; CHECK-DAG: cvt.s8.s64 [[E4:%rs[0-9]+]], [[RD4]];
56+
; CHECK-DAG: bfe.s64 [[RD5:%rd[0-9]+]], [[R]], 40, 8;
57+
; CHECK-DAG: cvt.s8.s64 [[E5:%rs[0-9]+]], [[RD5]];
58+
; CHECK-DAG: bfe.s64 [[RD6:%rd[0-9]+]], [[R]], 48, 8;
59+
; CHECK-DAG: cvt.s8.s64 [[E6:%rs[0-9]+]], [[RD6]];
60+
; CHECK-DAG: bfe.s64 [[RD7:%rd[0-9]+]], [[R]], 56, 8;
61+
; CHECK-DAG: cvt.s8.s64 [[E7:%rs[0-9]+]], [[RD7]];
62+
63+
; Reinterprets the i64 argument as <8 x i8>, extracts all eight elements with
; constant indices, sign-extends each to i16 and returns their sum.
define i16 @test_v8i8(i64 %a) {
64+
%v = bitcast i64 %a to <8 x i8>
65+
%r0 = extractelement <8 x i8> %v, i64 0
66+
%r1 = extractelement <8 x i8> %v, i64 1
67+
%r2 = extractelement <8 x i8> %v, i64 2
68+
%r3 = extractelement <8 x i8> %v, i64 3
69+
%r4 = extractelement <8 x i8> %v, i64 4
70+
%r5 = extractelement <8 x i8> %v, i64 5
71+
%r6 = extractelement <8 x i8> %v, i64 6
72+
%r7 = extractelement <8 x i8> %v, i64 7
73+
%r0i = sext i8 %r0 to i16
74+
%r1i = sext i8 %r1 to i16
75+
%r2i = sext i8 %r2 to i16
76+
%r3i = sext i8 %r3 to i16
77+
%r4i = sext i8 %r4 to i16
78+
%r5i = sext i8 %r5 to i16
79+
%r6i = sext i8 %r6 to i16
80+
%r7i = sext i8 %r7 to i16
81+
%r01 = add i16 %r0i, %r1i
82+
%r23 = add i16 %r2i, %r3i
83+
%r45 = add i16 %r4i, %r5i
84+
%r67 = add i16 %r6i, %r7i
85+
%r0123 = add i16 %r01, %r23
86+
%r4567 = add i16 %r45, %r67
87+
%r = add i16 %r0123, %r4567
88+
ret i16 %r
89+
}

0 commit comments

Comments
 (0)