Skip to content

Commit 470e7be

Browse files
ghehglanza
authored andcommitted
[CIR][CIRGen][Builtin][Neon] Lower neon_vshl_n_v and neon_vshlq_n_v (#965)
As title, but important step in this PR is to allow CIR ShiftOp to take vector of int type as input type. As result, I added a verifier to ShiftOp with 2 constraints 1. Input type either all vector or int type. This is consistent with LLVM::ShlOp, vector shift amount is expected. 2. In the spirit of C99 6.5.7.3, shift amount type must be the same as result type, the if vector type is used. (This is enforced in LLVM lowering for scalar int type).
1 parent c49428f commit 470e7be

File tree

10 files changed

+342
-118
lines changed

10 files changed

+342
-118
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,15 +1181,20 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
11811181
let summary = "Shift";
11821182
let description = [{
11831183
Shift `left` or `right`, according to the first operand. Second operand is
1184-
the shift target and the third the amount.
1184+
the shift target and the third the amount. Second and the thrid operand can
1185+
be either integer type or vector of integer type. However, they must be
1186+
either all vector of integer type, or all integer type. If they are vectors,
1187+
each vector element of the shift target is shifted by the corresponding
1188+
shift amount in the shift amount vector.
11851189

11861190
```mlir
11871191
%7 = cir.shift(left, %1 : !u64i, %4 : !s32i) -> !u64i
1192+
%8 = cir.shift(left, %2 : !cir.vector<!s32i x 2>, %3 : !cir.vector<!s32i x 2>) -> !cir.vector<!s32i x 2>
11881193
```
11891194
}];
11901195

1191-
let results = (outs CIR_IntType:$result);
1192-
let arguments = (ins CIR_IntType:$value, CIR_IntType:$amount,
1196+
let results = (outs CIR_AnyIntOrVecOfInt:$result);
1197+
let arguments = (ins CIR_AnyIntOrVecOfInt:$value, CIR_AnyIntOrVecOfInt:$amount,
11931198
UnitAttr:$isShiftleft);
11941199

11951200
let assemblyFormat = [{
@@ -1200,8 +1205,7 @@ def ShiftOp : CIR_Op<"shift", [Pure]> {
12001205
`)` `->` type($result) attr-dict
12011206
}];
12021207

1203-
// Already covered by the traits
1204-
let hasVerifier = 0;
1208+
let hasVerifier = 1;
12051209
}
12061210

12071211
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,9 @@ def IntegerVector : Type<
537537
]>, "!cir.vector of !cir.int"> {
538538
}
539539

540+
// Constraints
541+
def CIR_AnyIntOrVecOfInt: AnyTypeOf<[CIR_IntType, IntegerVector]>;
542+
540543
// Pointer to Arrays
541544
def ArrayPtr : Type<
542545
And<[

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,10 +2205,10 @@ static int64_t getIntValueFromConstOp(mlir::Value val) {
22052205
}
22062206

22072207
/// This function `buildCommonNeonCallPattern0` implements a common way
2208-
// to generate neon intrinsic call that has following pattern:
2209-
// 1. There is a need to cast result of the intrinsic call back to
2210-
// expression type.
2211-
// 2. Function arg types are given, not deduced from actual arg types.
2208+
/// to generate neon intrinsic call that has following pattern:
2209+
/// 1. There is a need to cast result of the intrinsic call back to
2210+
/// expression type.
2211+
/// 2. Function arg types are given, not deduced from actual arg types.
22122212
static mlir::Value
22132213
buildCommonNeonCallPattern0(CIRGenFunction &cgf, std::string &intrincsName,
22142214
llvm::SmallVector<mlir::Type> argTypes,
@@ -2222,6 +2222,23 @@ buildCommonNeonCallPattern0(CIRGenFunction &cgf, std::string &intrincsName,
22222222
return builder.createBitcast(res, resultType);
22232223
}
22242224

2225+
/// Build a constant shift amount vector of `vecTy` to shift a vector
2226+
/// Here `shitfVal` is a constant integer that will be splated into a
2227+
/// a const vector of `vecTy` which is the return of this function
2228+
static mlir::Value buildNeonShiftVector(CIRGenBuilderTy &builder,
2229+
mlir::Value shiftVal,
2230+
mlir::cir::VectorType vecTy,
2231+
mlir::Location loc, bool neg) {
2232+
int shiftAmt = getIntValueFromConstOp(shiftVal);
2233+
llvm::SmallVector<mlir::Attribute> vecAttr{
2234+
vecTy.getSize(),
2235+
// ConstVectorAttr requires cir::IntAttr
2236+
mlir::cir::IntAttr::get(vecTy.getEltType(), shiftAmt)};
2237+
mlir::cir::ConstVectorAttr constVecAttr = mlir::cir::ConstVectorAttr::get(
2238+
vecTy, mlir::ArrayAttr::get(builder.getContext(), vecAttr));
2239+
return builder.create<mlir::cir::ConstantOp>(loc, vecTy, constVecAttr);
2240+
}
2241+
22252242
mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
22262243
unsigned builtinID, unsigned llvmIntrinsic, unsigned altLLVMIntrinsic,
22272244
const char *nameHint, unsigned modifier, const CallExpr *e,
@@ -2298,6 +2315,13 @@ mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
22982315
: "llvm.aarch64.neon.sqrdmulh.lane",
22992316
resTy, getLoc(e->getExprLoc()));
23002317
}
2318+
case NEON::BI__builtin_neon_vshl_n_v:
2319+
case NEON::BI__builtin_neon_vshlq_n_v: {
2320+
mlir::Location loc = getLoc(e->getExprLoc());
2321+
ops[1] = buildNeonShiftVector(builder, ops[1], vTy, loc, false);
2322+
return builder.create<mlir::cir::ShiftOp>(
2323+
loc, vTy, builder.createBitcast(ops[0], vTy), ops[1], true);
2324+
}
23012325
}
23022326

23032327
// This second switch is for the intrinsics that might have a more generic

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3953,6 +3953,23 @@ LogicalResult BinOp::verify() {
39533953
return mlir::success();
39543954
}
39553955

3956+
//===----------------------------------------------------------------------===//
3957+
// ShiftOp Definitions
3958+
//===----------------------------------------------------------------------===//
3959+
LogicalResult ShiftOp::verify() {
3960+
mlir::Operation *op = getOperation();
3961+
mlir::Type resType = getResult().getType();
3962+
bool isOp0Vec = mlir::isa<mlir::cir::VectorType>(op->getOperand(0).getType());
3963+
bool isOp1Vec = mlir::isa<mlir::cir::VectorType>(op->getOperand(1).getType());
3964+
if (isOp0Vec != isOp1Vec)
3965+
return emitOpError() << "input types cannot be one vector and one scalar";
3966+
if (isOp1Vec && op->getOperand(1).getType() != resType) {
3967+
return emitOpError() << "shift amount must have the type of the result "
3968+
<< "if it is vector shift";
3969+
}
3970+
return mlir::success();
3971+
}
3972+
39563973
//===----------------------------------------------------------------------===//
39573974
// LabelOp Definitions
39583975
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2773,24 +2773,40 @@ class CIRShiftOpLowering
27732773
auto cirAmtTy =
27742774
mlir::dyn_cast<mlir::cir::IntType>(op.getAmount().getType());
27752775
auto cirValTy = mlir::dyn_cast<mlir::cir::IntType>(op.getValue().getType());
2776+
2777+
// Operands could also be vector type
2778+
auto cirAmtVTy =
2779+
mlir::dyn_cast<mlir::cir::VectorType>(op.getAmount().getType());
2780+
auto cirValVTy =
2781+
mlir::dyn_cast<mlir::cir::VectorType>(op.getValue().getType());
27762782
auto llvmTy = getTypeConverter()->convertType(op.getType());
27772783
mlir::Value amt = adaptor.getAmount();
27782784
mlir::Value val = adaptor.getValue();
27792785

2780-
assert(cirValTy && cirAmtTy && "non-integer shift is NYI");
2781-
assert(cirValTy == op.getType() && "inconsistent operands' types NYI");
2786+
assert(((cirValTy && cirAmtTy) || (cirAmtVTy && cirValVTy)) &&
2787+
"shift input type must be integer or vector type, otherwise NYI");
2788+
2789+
assert((cirValTy == op.getType() || cirValVTy == op.getType()) &&
2790+
"inconsistent operands' types NYI");
27822791

27832792
// Ensure shift amount is the same type as the value. Some undefined
27842793
// behavior might occur in the casts below as per [C99 6.5.7.3].
2785-
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
2786-
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
2787-
cirValTy.getWidth());
2794+
// Vector type shift amount needs no cast as type consistency is expected to
2795+
// be already be enforced at CIRGen.
2796+
if (cirAmtTy)
2797+
amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy),
2798+
!cirAmtTy.isSigned(), cirAmtTy.getWidth(),
2799+
cirValTy.getWidth());
27882800

27892801
// Lower to the proper LLVM shift operation.
27902802
if (op.getIsShiftleft())
27912803
rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
27922804
else {
2793-
if (cirValTy.isUnsigned())
2805+
bool isUnSigned =
2806+
cirValTy ? !cirValTy.isSigned()
2807+
: !mlir::cast<mlir::cir::IntType>(cirValVTy.getEltType())
2808+
.isSigned();
2809+
if (isUnSigned)
27942810
rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
27952811
else
27962812
rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);

0 commit comments

Comments
 (0)