Skip to content

Commit 62cdc2a

Browse files
authored
[NVPTX] Convert calls to indirect when call signature mismatches function signature (#107644)
When there is a function signature mismatch between a call instruction and the callee, lower the call to an indirect call. The current behavior is to produce direct calls that may or may not be valid PTX. Consider the following example with mismatching return types: ``` %struct.1 = type <{i64}> %struct.2 = type <{i64}> declare %struct.1 @callee() ... %call1 = call %struct.2 @callee() %call2 = call i64 @callee() ``` The return type of `callee` in PTX is `.b8 _[8]`. The return type of `%call1` will be the same and so the PTX has no problems. The return type of `%call2` will be `.b64`, so the types will not match and PTX will be unacceptable to ptxas. This despite all the types having the same size. The same is true for mismatching parameter types. If we instead convert these calls to indirect calls, we will generate functional PTX when the types have the same size. If they do not have the same size then the PTX will likely be incorrect, though this will not necessarily be caught by ptxas. Also, even if the sizes are the same, if the types differ then it is technically undefined behavior. This change allows for more flexibility in the bitcode that can be lowered to functioning PTX, at the cost of sometimes producing PTX that is less clearly wrong than it would have been previously (i.e. incorrect indirect calls are not as obviously wrong as incorrect direct calls). We consider it okay to generate PTX with undefined behavior as the behavior of calls with mismatching types is not explicitly defined.
1 parent 4eb9780 commit 62cdc2a

File tree

4 files changed

+178
-52
lines changed

4 files changed

+178
-52
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1658,6 +1658,15 @@ LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
16581658
return RetVal;
16591659
}
16601660

1661+
static bool shouldConvertToIndirectCall(const CallBase *CB,
1662+
const GlobalAddressSDNode *Func) {
1663+
if (!Func)
1664+
return false;
1665+
if (auto *CalleeFunc = dyn_cast<Function>(Func->getGlobal()))
1666+
return CB->getFunctionType() != CalleeFunc->getFunctionType();
1667+
return false;
1668+
}
1669+
16611670
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
16621671
SmallVectorImpl<SDValue> &InVals) const {
16631672

@@ -1972,10 +1981,14 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
19721981
VADeclareParam->getVTList(), DeclareParamOps);
19731982
}
19741983

1984+
// If the type of the callsite does not match that of the function, convert
1985+
// the callsite to an indirect call.
1986+
bool ConvertToIndirectCall = shouldConvertToIndirectCall(CB, Func);
1987+
19751988
// Both indirect calls and libcalls have nullptr Func. In order to distinguish
19761989
// between them we must rely on the call site value which is valid for
19771990
// indirect calls but is always null for libcalls.
1978-
bool isIndirectCall = !Func && CB;
1991+
bool isIndirectCall = (!Func && CB) || ConvertToIndirectCall;
19791992

19801993
if (isa<ExternalSymbolSDNode>(Callee)) {
19811994
Function* CalleeFunc = nullptr;
@@ -2027,6 +2040,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
20272040
Chain = DAG.getNode(Opcode, dl, PrintCallVTs, PrintCallOps);
20282041
InGlue = Chain.getValue(1);
20292042

2043+
if (ConvertToIndirectCall) {
2044+
// Copy the function ptr to a ptx register and use the register to call the
2045+
// function.
2046+
EVT DestVT = Callee.getValueType();
2047+
MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
2048+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2049+
unsigned DestReg =
2050+
RegInfo.createVirtualRegister(TLI.getRegClassFor(DestVT.getSimpleVT()));
2051+
auto RegCopy = DAG.getCopyToReg(DAG.getEntryNode(), dl, DestReg, Callee);
2052+
Callee = DAG.getCopyFromReg(RegCopy, dl, DestReg, DestVT);
2053+
}
2054+
20302055
// Ops to print out the function name
20312056
SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
20322057
SDValue CallVoidOps[] = { Chain, Callee, InGlue };

llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ target triple = "nvptx64-nvidia-cuda"
1717
; CHECK: st.param.b16 [param2+0], %rs1;
1818
; CHECK: st.param.b16 [param2+2], %rs2;
1919
; CHECK: .param .align 2 .b8 retval0[4];
20-
; CHECK: call.uni (retval0),
21-
; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
20+
; CHECK-NEXT: prototype_0 : .callprototype (.param .align 2 .b8 _[4]) _ (.param .b32 _, .param .b32 _, .param .align 2 .b8 _[4]);
21+
; CHECK-NEXT: call (retval0),
2222
define weak_odr void @foo() {
2323
entry:
2424
%call.i.i.i = tail call %"class.complex" @_Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE(i32 0, i32 0, ptr byval(%"class.complex") null)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s
2+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify %}
3+
4+
%struct.64 = type <{ i64 }>
5+
declare i64 @callee(ptr %p);
6+
declare i64 @callee_variadic(ptr %p, ...);
7+
8+
define %struct.64 @test_return_type_mismatch(ptr %p) {
9+
; CHECK-LABEL: test_return_type_mismatch(
10+
; CHECK: .param .align 1 .b8 retval0[8];
11+
; CHECK-NEXT: prototype_0 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
12+
; CHECK-NEXT: call (retval0),
13+
; CHECK-NEXT: %rd
14+
; CHECK-NEXT: (
15+
; CHECK-NEXT: param0
16+
; CHECK-NEXT: )
17+
; CHECK-NEXT: , prototype_0;
18+
%ret = call %struct.64 @callee(ptr %p)
19+
ret %struct.64 %ret
20+
}
21+
22+
define i64 @test_param_type_mismatch(ptr %p) {
23+
; CHECK-LABEL: test_param_type_mismatch(
24+
; CHECK: .param .b64 retval0;
25+
; CHECK-NEXT: prototype_1 : .callprototype (.param .b64 _) _ (.param .b64 _);
26+
; CHECK-NEXT: call (retval0),
27+
; CHECK-NEXT: %rd
28+
; CHECK-NEXT: (
29+
; CHECK-NEXT: param0
30+
; CHECK-NEXT: )
31+
; CHECK-NEXT: , prototype_1;
32+
%ret = call i64 @callee(i64 7)
33+
ret i64 %ret
34+
}
35+
36+
define i64 @test_param_count_mismatch(ptr %p) {
37+
; CHECK-LABEL: test_param_count_mismatch(
38+
; CHECK: .param .b64 retval0;
39+
; CHECK-NEXT: prototype_2 : .callprototype (.param .b64 _) _ (.param .b64 _, .param .b64 _);
40+
; CHECK-NEXT: call (retval0),
41+
; CHECK-NEXT: %rd
42+
; CHECK-NEXT: (
43+
; CHECK-NEXT: param0,
44+
; CHECK-NEXT: param1
45+
; CHECK-NEXT: )
46+
; CHECK-NEXT: , prototype_2;
47+
%ret = call i64 @callee(ptr %p, i64 7)
48+
ret i64 %ret
49+
}
50+
51+
define %struct.64 @test_return_type_mismatch_variadic(ptr %p) {
52+
; CHECK-LABEL: test_return_type_mismatch_variadic(
53+
; CHECK: .param .align 1 .b8 retval0[8];
54+
; CHECK-NEXT: prototype_3 : .callprototype (.param .align 1 .b8 _[8]) _ (.param .b64 _);
55+
; CHECK-NEXT: call (retval0),
56+
; CHECK-NEXT: %rd
57+
; CHECK-NEXT: (
58+
; CHECK-NEXT: param0
59+
; CHECK-NEXT: )
60+
; CHECK-NEXT: , prototype_3;
61+
%ret = call %struct.64 (ptr, ...) @callee_variadic(ptr %p)
62+
ret %struct.64 %ret
63+
}
64+
65+
define i64 @test_param_type_mismatch_variadic(ptr %p) {
66+
; CHECK-LABEL: test_param_type_mismatch_variadic(
67+
; CHECK: .param .b64 retval0;
68+
; CHECK-NEXT: call.uni (retval0),
69+
; CHECK-NEXT: callee_variadic
70+
; CHECK-NEXT: (
71+
; CHECK-NEXT: param0,
72+
; CHECK-NEXT: param1
73+
; CHECK-NEXT: )
74+
%ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
75+
ret i64 %ret
76+
}
77+
78+
define i64 @test_param_count_mismatch_variadic(ptr %p) {
79+
; CHECK-LABEL: test_param_count_mismatch_variadic(
80+
; CHECK: .param .b64 retval0;
81+
; CHECK-NEXT: call.uni (retval0),
82+
; CHECK-NEXT: callee_variadic
83+
; CHECK-NEXT: (
84+
; CHECK-NEXT: param0,
85+
; CHECK-NEXT: param1
86+
; CHECK-NEXT: )
87+
%ret = call i64 (ptr, ...) @callee_variadic(ptr %p, i64 7)
88+
ret i64 %ret
89+
}

llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll

Lines changed: 61 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,24 @@ define void @grid_const_escape(ptr byval(%struct.s) align 4 %input) {
7272
; PTX-LABEL: grid_const_escape(
7373
; PTX: {
7474
; PTX-NEXT: .reg .b32 %r<3>;
75-
; PTX-NEXT: .reg .b64 %rd<4>;
75+
; PTX-NEXT: .reg .b64 %rd<5>;
7676
; PTX-EMPTY:
7777
; PTX-NEXT: // %bb.0:
78-
; PTX-NEXT: mov.b64 %rd1, grid_const_escape_param_0;
79-
; PTX-NEXT: mov.u64 %rd2, %rd1;
80-
; PTX-NEXT: cvta.param.u64 %rd3, %rd2;
78+
; PTX-NEXT: mov.b64 %rd2, grid_const_escape_param_0;
79+
; PTX-NEXT: mov.u64 %rd3, %rd2;
80+
; PTX-NEXT: cvta.param.u64 %rd4, %rd3;
81+
; PTX-NEXT: mov.u64 %rd1, escape;
8182
; PTX-NEXT: { // callseq 0, 0
8283
; PTX-NEXT: .param .b64 param0;
83-
; PTX-NEXT: st.param.b64 [param0+0], %rd3;
84+
; PTX-NEXT: st.param.b64 [param0+0], %rd4;
8485
; PTX-NEXT: .param .b32 retval0;
85-
; PTX-NEXT: call.uni (retval0),
86-
; PTX-NEXT: escape,
86+
; PTX-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .b64 _);
87+
; PTX-NEXT: call (retval0),
88+
; PTX-NEXT: %rd1,
8789
; PTX-NEXT: (
8890
; PTX-NEXT: param0
89-
; PTX-NEXT: );
91+
; PTX-NEXT: )
92+
; PTX-NEXT: , prototype_0;
9093
; PTX-NEXT: ld.param.b32 %r1, [retval0+0];
9194
; PTX-NEXT: } // callseq 0
9295
; PTX-NEXT: ret;
@@ -107,36 +110,39 @@ define void @multiple_grid_const_escape(ptr byval(%struct.s) align 4 %input, i32
107110
; PTX-NEXT: .reg .b64 %SP;
108111
; PTX-NEXT: .reg .b64 %SPL;
109112
; PTX-NEXT: .reg .b32 %r<4>;
110-
; PTX-NEXT: .reg .b64 %rd<9>;
113+
; PTX-NEXT: .reg .b64 %rd<10>;
111114
; PTX-EMPTY:
112115
; PTX-NEXT: // %bb.0:
113116
; PTX-NEXT: mov.u64 %SPL, __local_depot3;
114117
; PTX-NEXT: cvta.local.u64 %SP, %SPL;
115-
; PTX-NEXT: mov.b64 %rd1, multiple_grid_const_escape_param_0;
116-
; PTX-NEXT: mov.b64 %rd2, multiple_grid_const_escape_param_2;
117-
; PTX-NEXT: mov.u64 %rd3, %rd2;
118+
; PTX-NEXT: mov.b64 %rd2, multiple_grid_const_escape_param_0;
119+
; PTX-NEXT: mov.b64 %rd3, multiple_grid_const_escape_param_2;
120+
; PTX-NEXT: mov.u64 %rd4, %rd3;
118121
; PTX-NEXT: ld.param.u32 %r1, [multiple_grid_const_escape_param_1];
119-
; PTX-NEXT: cvta.param.u64 %rd4, %rd3;
120-
; PTX-NEXT: mov.u64 %rd5, %rd1;
121-
; PTX-NEXT: cvta.param.u64 %rd6, %rd5;
122-
; PTX-NEXT: add.u64 %rd7, %SP, 0;
123-
; PTX-NEXT: add.u64 %rd8, %SPL, 0;
124-
; PTX-NEXT: st.local.u32 [%rd8], %r1;
122+
; PTX-NEXT: cvta.param.u64 %rd5, %rd4;
123+
; PTX-NEXT: mov.u64 %rd6, %rd2;
124+
; PTX-NEXT: cvta.param.u64 %rd7, %rd6;
125+
; PTX-NEXT: add.u64 %rd8, %SP, 0;
126+
; PTX-NEXT: add.u64 %rd9, %SPL, 0;
127+
; PTX-NEXT: st.local.u32 [%rd9], %r1;
128+
; PTX-NEXT: mov.u64 %rd1, escape3;
125129
; PTX-NEXT: { // callseq 1, 0
126130
; PTX-NEXT: .param .b64 param0;
127-
; PTX-NEXT: st.param.b64 [param0+0], %rd6;
131+
; PTX-NEXT: st.param.b64 [param0+0], %rd7;
128132
; PTX-NEXT: .param .b64 param1;
129-
; PTX-NEXT: st.param.b64 [param1+0], %rd7;
133+
; PTX-NEXT: st.param.b64 [param1+0], %rd8;
130134
; PTX-NEXT: .param .b64 param2;
131-
; PTX-NEXT: st.param.b64 [param2+0], %rd4;
135+
; PTX-NEXT: st.param.b64 [param2+0], %rd5;
132136
; PTX-NEXT: .param .b32 retval0;
133-
; PTX-NEXT: call.uni (retval0),
134-
; PTX-NEXT: escape3,
137+
; PTX-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .b64 _, .param .b64 _, .param .b64 _);
138+
; PTX-NEXT: call (retval0),
139+
; PTX-NEXT: %rd1,
135140
; PTX-NEXT: (
136141
; PTX-NEXT: param0,
137142
; PTX-NEXT: param1,
138143
; PTX-NEXT: param2
139-
; PTX-NEXT: );
144+
; PTX-NEXT: )
145+
; PTX-NEXT: , prototype_1;
140146
; PTX-NEXT: ld.param.b32 %r2, [retval0+0];
141147
; PTX-NEXT: } // callseq 1
142148
; PTX-NEXT: ret;
@@ -221,26 +227,29 @@ define void @grid_const_partial_escape(ptr byval(i32) %input, ptr %output) {
221227
; PTX-LABEL: grid_const_partial_escape(
222228
; PTX: {
223229
; PTX-NEXT: .reg .b32 %r<5>;
224-
; PTX-NEXT: .reg .b64 %rd<6>;
230+
; PTX-NEXT: .reg .b64 %rd<7>;
225231
; PTX-EMPTY:
226232
; PTX-NEXT: // %bb.0:
227-
; PTX-NEXT: mov.b64 %rd1, grid_const_partial_escape_param_0;
228-
; PTX-NEXT: ld.param.u64 %rd2, [grid_const_partial_escape_param_1];
229-
; PTX-NEXT: cvta.to.global.u64 %rd3, %rd2;
230-
; PTX-NEXT: mov.u64 %rd4, %rd1;
231-
; PTX-NEXT: cvta.param.u64 %rd5, %rd4;
232-
; PTX-NEXT: ld.u32 %r1, [%rd5];
233+
; PTX-NEXT: mov.b64 %rd2, grid_const_partial_escape_param_0;
234+
; PTX-NEXT: ld.param.u64 %rd3, [grid_const_partial_escape_param_1];
235+
; PTX-NEXT: cvta.to.global.u64 %rd4, %rd3;
236+
; PTX-NEXT: mov.u64 %rd5, %rd2;
237+
; PTX-NEXT: cvta.param.u64 %rd6, %rd5;
238+
; PTX-NEXT: ld.u32 %r1, [%rd6];
233239
; PTX-NEXT: add.s32 %r2, %r1, %r1;
234-
; PTX-NEXT: st.global.u32 [%rd3], %r2;
240+
; PTX-NEXT: st.global.u32 [%rd4], %r2;
241+
; PTX-NEXT: mov.u64 %rd1, escape;
235242
; PTX-NEXT: { // callseq 2, 0
236243
; PTX-NEXT: .param .b64 param0;
237-
; PTX-NEXT: st.param.b64 [param0+0], %rd5;
244+
; PTX-NEXT: st.param.b64 [param0+0], %rd6;
238245
; PTX-NEXT: .param .b32 retval0;
239-
; PTX-NEXT: call.uni (retval0),
240-
; PTX-NEXT: escape,
246+
; PTX-NEXT: prototype_2 : .callprototype (.param .b32 _) _ (.param .b64 _);
247+
; PTX-NEXT: call (retval0),
248+
; PTX-NEXT: %rd1,
241249
; PTX-NEXT: (
242250
; PTX-NEXT: param0
243-
; PTX-NEXT: );
251+
; PTX-NEXT: )
252+
; PTX-NEXT: , prototype_2;
244253
; PTX-NEXT: ld.param.b32 %r3, [retval0+0];
245254
; PTX-NEXT: } // callseq 2
246255
; PTX-NEXT: ret;
@@ -266,27 +275,30 @@ define i32 @grid_const_partial_escapemem(ptr byval(%struct.s) %input, ptr %outpu
266275
; PTX-LABEL: grid_const_partial_escapemem(
267276
; PTX: {
268277
; PTX-NEXT: .reg .b32 %r<6>;
269-
; PTX-NEXT: .reg .b64 %rd<6>;
278+
; PTX-NEXT: .reg .b64 %rd<7>;
270279
; PTX-EMPTY:
271280
; PTX-NEXT: // %bb.0:
272-
; PTX-NEXT: mov.b64 %rd1, grid_const_partial_escapemem_param_0;
273-
; PTX-NEXT: ld.param.u64 %rd2, [grid_const_partial_escapemem_param_1];
274-
; PTX-NEXT: cvta.to.global.u64 %rd3, %rd2;
275-
; PTX-NEXT: mov.u64 %rd4, %rd1;
276-
; PTX-NEXT: cvta.param.u64 %rd5, %rd4;
277-
; PTX-NEXT: ld.u32 %r1, [%rd5];
278-
; PTX-NEXT: ld.u32 %r2, [%rd5+4];
279-
; PTX-NEXT: st.global.u64 [%rd3], %rd5;
281+
; PTX-NEXT: mov.b64 %rd2, grid_const_partial_escapemem_param_0;
282+
; PTX-NEXT: ld.param.u64 %rd3, [grid_const_partial_escapemem_param_1];
283+
; PTX-NEXT: cvta.to.global.u64 %rd4, %rd3;
284+
; PTX-NEXT: mov.u64 %rd5, %rd2;
285+
; PTX-NEXT: cvta.param.u64 %rd6, %rd5;
286+
; PTX-NEXT: ld.u32 %r1, [%rd6];
287+
; PTX-NEXT: ld.u32 %r2, [%rd6+4];
288+
; PTX-NEXT: st.global.u64 [%rd4], %rd6;
280289
; PTX-NEXT: add.s32 %r3, %r1, %r2;
290+
; PTX-NEXT: mov.u64 %rd1, escape;
281291
; PTX-NEXT: { // callseq 3, 0
282292
; PTX-NEXT: .param .b64 param0;
283-
; PTX-NEXT: st.param.b64 [param0+0], %rd5;
293+
; PTX-NEXT: st.param.b64 [param0+0], %rd6;
284294
; PTX-NEXT: .param .b32 retval0;
285-
; PTX-NEXT: call.uni (retval0),
286-
; PTX-NEXT: escape,
295+
; PTX-NEXT: prototype_3 : .callprototype (.param .b32 _) _ (.param .b64 _);
296+
; PTX-NEXT: call (retval0),
297+
; PTX-NEXT: %rd1,
287298
; PTX-NEXT: (
288299
; PTX-NEXT: param0
289-
; PTX-NEXT: );
300+
; PTX-NEXT: )
301+
; PTX-NEXT: , prototype_3;
290302
; PTX-NEXT: ld.param.b32 %r4, [retval0+0];
291303
; PTX-NEXT: } // callseq 3
292304
; PTX-NEXT: st.param.b32 [func_retval0+0], %r3;

0 commit comments

Comments
 (0)