Skip to content

Commit dcc7cfb

Browse files
seven-milelanza
authored andcommitted
[CIR][CodeGen][LowerToLLVM] Set calling convention for call ops (#836)
This PR implements the CIRGen and Lowering part of calling convention attribute of `cir.call`-like operations. Here we have **4 kinds of operations**: (direct or indirect) x (`call` or `try_call`). According to our need and feasibility of constructing a test case, this PR includes: * For CIRGen, only direct `call`. Until now, the only extra calling conventions are SPIR ones, which cannot be set from source code manually using attributes. Meanwhile, OpenCL C *does not allow* function pointers or exceptions, therefore the only case remaining is direct call. * For Lowering, direct and indirect `call`, but not any `try_call`. Although it's possible to write all 4 kinds of calls with calling convention in ClangIR assembly, exceptions is quite hard to write and read. I prefer source-code-level test for it when it's available in the future. For example, possibly C++ `thiscall` with exceptions. * Extra: the verification of calling convention consistency for direct `call` and direct `try_call`. All unsupported cases are guarded by assertions or MLIR diags.
1 parent 7365fd2 commit dcc7cfb

File tree

6 files changed

+80
-15
lines changed

6 files changed

+80
-15
lines changed

clang/lib/CIR/CodeGen/CIRGenCall.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
447447
mlir::cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
448448
mlir::cir::FuncOp directFuncOp,
449449
SmallVectorImpl<mlir::Value> &CIRCallArgs,
450-
mlir::Operation *InvokeDest,
450+
mlir::Operation *InvokeDest, mlir::cir::CallingConv callingConv,
451451
mlir::cir::ExtraFuncAttributesAttr extraFnAttrs) {
452452
auto &builder = CGF.getBuilder();
453453

@@ -468,6 +468,8 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
468468
}
469469

470470
mlir::cir::CallOp tryCallOp;
471+
// TODO(cir): Set calling convention for `cir.try_call`.
472+
assert(callingConv == mlir::cir::CallingConv::C && "NYI");
471473
if (indirectFuncTy) {
472474
tryCallOp = builder.createIndirectTryCallOp(callLoc, indirectFuncVal,
473475
indirectFuncTy, CIRCallArgs);
@@ -484,12 +486,15 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
484486
}
485487

486488
assert(builder.getInsertionBlock() && "expected valid basic block");
487-
if (indirectFuncTy)
489+
if (indirectFuncTy) {
490+
// TODO(cir): Set calling convention for indirect calls.
491+
assert(callingConv == mlir::cir::CallingConv::C && "NYI");
488492
return builder.createIndirectCallOp(
489493
callLoc, indirectFuncVal, indirectFuncTy, CIRCallArgs,
490494
mlir::cir::CallingConv::C, extraFnAttrs);
491-
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs,
492-
mlir::cir::CallingConv::C, extraFnAttrs);
495+
}
496+
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs, callingConv,
497+
extraFnAttrs);
493498
}
494499

495500
RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
@@ -765,9 +770,9 @@ RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
765770
auto extraFnAttrs = mlir::cir::ExtraFuncAttributesAttr::get(
766771
builder.getContext(), Attrs.getDictionary(builder.getContext()));
767772

768-
mlir::cir::CIRCallOpInterface callLikeOp =
769-
buildCallLikeOp(*this, callLoc, indirectFuncTy, indirectFuncVal,
770-
directFuncOp, CIRCallArgs, InvokeDest, extraFnAttrs);
773+
mlir::cir::CIRCallOpInterface callLikeOp = buildCallLikeOp(
774+
*this, callLoc, indirectFuncTy, indirectFuncVal, directFuncOp,
775+
CIRCallArgs, InvokeDest, callingConv, extraFnAttrs);
771776

772777
if (E)
773778
callLikeOp->setAttr(

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2711,6 +2711,12 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
27112711
<< op->getOperand(i).getType() << " for operand number " << i;
27122712
}
27132713

2714+
// Calling convention must match.
2715+
if (callIf.getCallingConv() != fn.getCallingConv())
2716+
return op->emitOpError("calling convention mismatch: expected ")
2717+
<< stringifyCallingConv(fn.getCallingConv()) << ", but provided "
2718+
<< stringifyCallingConv(callIf.getCallingConv());
2719+
27142720
// Void function must not return any results.
27152721
if (fnType.isVoid() && op->getNumResults() != 0)
27162722
return op->emitOpError("callee returns void but call has results");

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -875,18 +875,24 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
875875
mlir::Block *landingPadBlock = nullptr) {
876876
llvm::SmallVector<mlir::Type, 8> llvmResults;
877877
auto cirResults = op->getResultTypes();
878+
auto callIf = cast<mlir::cir::CIRCallOpInterface>(op);
878879

879880
if (converter->convertTypes(cirResults, llvmResults).failed())
880881
return mlir::failure();
881882

883+
auto cconv = convertCallingConv(callIf.getCallingConv());
884+
882885
if (calleeAttr) { // direct call
883-
if (landingPadBlock)
884-
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
886+
if (landingPadBlock) {
887+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
885888
op, llvmResults, calleeAttr, callOperands, continueBlock,
886889
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
887-
else
888-
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, llvmResults,
889-
calleeAttr, callOperands);
890+
newOp.setCConv(cconv);
891+
} else {
892+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
893+
op, llvmResults, calleeAttr, callOperands);
894+
newOp.setCConv(cconv);
895+
}
890896
} else { // indirect call
891897
assert(op->getOperands().size() &&
892898
"operands list must no be empty for the indirect call");
@@ -899,14 +905,17 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
899905
if (landingPadBlock) {
900906
auto llvmFnTy =
901907
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp));
902-
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
908+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
903909
op, llvmFnTy, mlir::FlatSymbolRefAttr{}, callOperands, continueBlock,
904910
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
905-
} else
906-
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
911+
newOp.setCConv(cconv);
912+
} else {
913+
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
907914
op,
908915
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp)),
909916
callOperands);
917+
newOp.setCConv(cconv);
918+
}
910919
}
911920
return mlir::success();
912921
}
@@ -932,6 +941,10 @@ class CIRTryCallLowering
932941
mlir::LogicalResult
933942
matchAndRewrite(mlir::cir::TryCallOp op, OpAdaptor adaptor,
934943
mlir::ConversionPatternRewriter &rewriter) const override {
944+
if (op.getCallingConv() != mlir::cir::CallingConv::C) {
945+
return op.emitError(
946+
"non-C calling convention is not implemented for try_call");
947+
}
935948
return rewriteToCallOrInvoke(
936949
op.getOperation(), adaptor.getOperands(), rewriter, getTypeConverter(),
937950
op.getCalleeAttr(), op.getCont(), op.getLandingPad());

clang/test/CIR/CodeGen/OpenCL/spir-calling-conv.cl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ kernel void bar(global int *A);
1515
// LLVM-DAG: define{{.*}} spir_kernel void @foo(
1616
kernel void foo(global int *A) {
1717
int id = get_dummy_id(0);
18+
// CIR: %{{[0-9]+}} = cir.call @get_dummy_id(%2) : (!s32i) -> !s32i cc(spir_function)
19+
// LLVM: %{{[a-z0-9_]+}} = call spir_func i32 @get_dummy_id(
1820
A[id] = id;
1921
bar(A);
22+
// CIR: cir.call @bar(%8) : (!cir.ptr<!s32i, addrspace(offload_global)>) -> () cc(spir_kernel)
23+
// LLVM: call spir_kernel void @bar(ptr addrspace(1)
2024
}

clang/test/CIR/IR/invalid.cir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,22 @@ module {
13051305
!s32i = !cir.int<s, 32>
13061306

13071307
module {
1308+
cir.func @subroutine() cc(spir_function) {
1309+
cir.return
1310+
}
1311+
1312+
cir.func @call_conv_match() {
1313+
// expected-error@+1 {{'cir.call' op calling convention mismatch: expected spir_function, but provided spir_kernel}}
1314+
cir.call @subroutine(): () -> !cir.void cc(spir_kernel)
1315+
cir.return
1316+
}
1317+
}
13081318

1319+
// -----
1320+
1321+
!s32i = !cir.int<s, 32>
1322+
1323+
module {
13091324
cir.func @test_bitcast_addrspace() {
13101325
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["tmp"] {alignment = 4 : i64}
13111326
// expected-error@+1 {{'cir.cast' op result type address space does not match the address space of the operand}}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: cir-translate -cir-to-llvmir %s -o %t.ll
2+
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM
3+
4+
!s32i = !cir.int<s, 32>
5+
!fnptr = !cir.ptr<!cir.func<!s32i (!s32i)>>
6+
7+
module {
8+
cir.func private @my_add(%a: !s32i, %b: !s32i) -> !s32i cc(spir_function)
9+
10+
cir.func @ind(%fnptr: !fnptr, %a : !s32i) {
11+
%1 = cir.call %fnptr(%a) : (!fnptr, !s32i) -> !s32i cc(spir_kernel)
12+
// LLVM: %{{[0-9]+}} = call spir_kernel i32 %{{[0-9]+}}(i32 %{{[0-9]+}})
13+
14+
%2 = cir.call %fnptr(%a) : (!fnptr, !s32i) -> !s32i cc(spir_function)
15+
// LLVM: %{{[0-9]+}} = call spir_func i32 %{{[0-9]+}}(i32 %{{[0-9]+}})
16+
17+
%3 = cir.call @my_add(%1, %2) : (!s32i, !s32i) -> !s32i cc(spir_function)
18+
// LLVM: %{{[0-9]+}} = call spir_func i32 @my_add(i32 %{{[0-9]+}}, i32 %{{[0-9]+}})
19+
20+
cir.return
21+
}
22+
}

0 commit comments

Comments
 (0)