Skip to content

Commit ed5bf45

Browse files
NimishMishrakiranchandramohan
authored andcommitted
Cherrypicked atomic operation based changes from llvm main
1 parent ca0677d commit ed5bf45

File tree

4 files changed

+484
-178
lines changed

4 files changed

+484
-178
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -317,14 +317,12 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
317317

318318
def YieldOp : OpenMP_Op<"yield",
319319
[NoSideEffect, ReturnLike, Terminator,
320-
ParentOneOf<["WsLoopOp", "ReductionDeclareOp"]>]> {
320+
ParentOneOf<["WsLoopOp", "ReductionDeclareOp", "AtomicUpdateOp"]>]> {
321321
let summary = "loop yield and termination operation";
322322
let description = [{
323323
"omp.yield" yields SSA values from the OpenMP dialect op region and
324324
terminates the region. The semantics of how the values are yielded is
325325
defined by the parent operation.
326-
If "omp.yield" has any operands, the operands must match the parent
327-
operation's results.
328326
}];
329327

330328
let arguments = (ins Variadic<AnyType>:$results);
@@ -559,11 +557,11 @@ def TaskwaitOp : OpenMP_Op<"taskwait"> {
559557
// value of the clause) here decomposes handling of this construct into a
560558
// two-step process.
561559

562-
def AtomicReadOp : OpenMP_Op<"atomic.read"> {
563-
let arguments = (ins OpenMP_PointerLikeType:$address,
560+
def AtomicReadOp : OpenMP_Op<"atomic.read", [AllTypesMatch<["x", "v"]>]> {
561+
let arguments = (ins OpenMP_PointerLikeType:$x,
562+
OpenMP_PointerLikeType:$v,
564563
DefaultValuedAttr<I64Attr, "0">:$hint,
565564
OptionalAttr<MemoryOrderKind>:$memory_order);
566-
let results = (outs AnyType);
567565
let parser = [{ return parseAtomicReadOp(parser, result); }];
568566
let printer = [{ return printAtomicReadOp(p, *this); }];
569567
let verifier = [{ return verifyAtomicReadOp(*this); }];
@@ -606,18 +604,25 @@ def AtomicBinOpKindAttr : I64EnumAttr<
606604
let symbolToStringFnName = "AtomicBinOpKindToString";
607605
}
608606

609-
def AtomicUpdateOp : OpenMP_Op<"atomic.update"> {
607+
def AtomicUpdateOp : OpenMP_Op<"atomic.update", [SingleBlockImplicitTerminator<"YieldOp">]> {
610608
let arguments = (ins OpenMP_PointerLikeType:$x,
611-
AnyType:$expr,
612-
UnitAttr:$isXBinopExpr,
613-
AtomicBinOpKindAttr:$binop,
614609
DefaultValuedAttr<I64Attr, "0">:$hint,
615610
OptionalAttr<MemoryOrderKind>:$memory_order);
611+
let regions = (region SizedRegion<1>:$region);
616612
let parser = [{ return parseAtomicUpdateOp(parser, result); }];
617613
let printer = [{ return printAtomicUpdateOp(p, *this); }];
618614
let verifier = [{ return verifyAtomicUpdateOp(*this); }];
619615
}
620616

617+
def AtomicCaptureOp : OpenMP_Op<"atomic.capture", [SingleBlockImplicitTerminator<"TerminatorOp">]>{
618+
let arguments = (ins DefaultValuedAttr<I64Attr, "0">:$hint,
619+
OptionalAttr<MemoryOrderKind>:$memory_order);
620+
let regions = (region SizedRegion<1>:$region);
621+
let parser = [{ return parseAtomicCaptureOp(parser, result); }];
622+
let printer = [{ return printAtomicCaptureOp(p, *this); }];
623+
let verifier = [{ return verifyAtomicCaptureOp(*this); }];
624+
}
625+
621626
//===----------------------------------------------------------------------===//
622627
// 2.19.5.7 declare reduction Directive
623628
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 101 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,32 +1227,28 @@ static LogicalResult verifyOrderedRegionOp(OrderedRegionOp op) {
12271227
/// address ::= operand `:` type
12281228
static ParseResult parseAtomicReadOp(OpAsmParser &parser,
12291229
OperationState &result) {
1230-
OpAsmParser::OperandType address;
1230+
OpAsmParser::OperandType x, v;
12311231
Type addressType;
12321232
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
12331233
SmallVector<int> segments;
12341234

1235-
if (parser.parseOperand(address) ||
1235+
if (parser.parseOperand(v) || parser.parseEqual() || parser.parseOperand(x) ||
12361236
parseClauses(parser, result, clauses, segments) ||
12371237
parser.parseColonType(addressType) ||
1238-
parser.resolveOperand(address, addressType, result.operands))
1238+
parser.resolveOperand(x, addressType, result.operands) ||
1239+
parser.resolveOperand(v, addressType, result.operands))
12391240
return failure();
1240-
1241-
SmallVector<Type> resultType;
1242-
if (parser.parseArrowTypeList(resultType))
1243-
return failure();
1244-
result.addTypes(resultType);
12451241
return success();
12461242
}
12471243

12481244
/// Printer for AtomicReadOp
12491245
static void printAtomicReadOp(OpAsmPrinter &p, AtomicReadOp op) {
1250-
p << " " << op.address() << " ";
1246+
p << " " << op.v() << " = " << op.x() << " ";
12511247
if (op.memory_order())
12521248
p << "memory_order(" << op.memory_order().getValue() << ") ";
12531249
if (op.hintAttr())
12541250
printSynchronizationHint(p << " ", op, op.hintAttr());
1255-
p << ": " << op.address().getType() << " -> " << op.getType();
1251+
p << ": " << op.x().getType();
12561252
return;
12571253
}
12581254

@@ -1264,6 +1260,9 @@ static LogicalResult verifyAtomicReadOp(AtomicReadOp op) {
12641260
return op.emitError(
12651261
"memory-order must not be acq_rel or release for atomic reads");
12661262
}
1263+
if (op.x() == op.v())
1264+
return op.emitError(
1265+
"read and write must not be to the same location for atomic reads");
12671266
return verifySynchronizationHint(op, op.hint());
12681267
}
12691268

@@ -1284,7 +1283,7 @@ static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
12841283
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
12851284
SmallVector<int> segments;
12861285

1287-
if (parser.parseOperand(address) || parser.parseComma() ||
1286+
if (parser.parseOperand(address) || parser.parseEqual() ||
12881287
parser.parseOperand(value) ||
12891288
parseClauses(parser, result, clauses, segments) ||
12901289
parser.parseColonType(addrType) || parser.parseComma() ||
@@ -1297,7 +1296,7 @@ static ParseResult parseAtomicWriteOp(OpAsmParser &parser,
12971296

12981297
/// Printer for AtomicWriteOp
12991298
static void printAtomicWriteOp(OpAsmPrinter &p, AtomicWriteOp op) {
1300-
p << " " << op.address() << ", " << op.value() << " ";
1299+
p << " " << op.address() << " = " << op.value() << " ";
13011300
if (op.memory_order())
13021301
p << "memory_order(" << op.memory_order() << ") ";
13031302
if (op.hintAttr())
@@ -1328,61 +1327,28 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
13281327
OperationState &result) {
13291328
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
13301329
SmallVector<int> segments;
1331-
OpAsmParser::OperandType x, y, z;
1332-
Type xType, exprType;
1333-
StringRef binOp;
1334-
1335-
// x = y `op` z : xtype, exprtype
1336-
if (parser.parseOperand(x) || parser.parseEqual() || parser.parseOperand(y) ||
1337-
parser.parseKeyword(&binOp) || parser.parseOperand(z) ||
1338-
parseClauses(parser, result, clauses, segments) || parser.parseColon() ||
1339-
parser.parseType(xType) || parser.parseComma() ||
1340-
parser.parseType(exprType) ||
1341-
parser.resolveOperand(x, xType, result.operands)) {
1330+
OpAsmParser::OperandType x, expr;
1331+
Type xType;
1332+
1333+
if (parseClauses(parser, result, clauses, segments) ||
1334+
parser.parseOperand(x) || parser.parseColon() ||
1335+
parser.parseType(xType) ||
1336+
parser.resolveOperand(x, xType, result.operands) ||
1337+
parser.parseRegion(*result.addRegion())) {
13421338
return failure();
13431339
}
1344-
1345-
auto binOpEnum = AtomicBinOpKindToEnum(binOp.upper());
1346-
if (!binOpEnum)
1347-
return parser.emitError(parser.getNameLoc())
1348-
<< "invalid atomic bin op in atomic update\n";
1349-
auto attr =
1350-
parser.getBuilder().getI64IntegerAttr((int64_t)binOpEnum.getValue());
1351-
result.addAttribute("binop", attr);
1352-
1353-
OpAsmParser::OperandType expr;
1354-
if (x.name == y.name && x.number == y.number) {
1355-
expr = z;
1356-
result.addAttribute("isXBinopExpr", parser.getBuilder().getUnitAttr());
1357-
} else if (x.name == z.name && x.number == z.number) {
1358-
expr = y;
1359-
} else {
1360-
return parser.emitError(parser.getNameLoc())
1361-
<< "atomic update variable " << x.name
1362-
<< " not found in the RHS of the assignment statement in an"
1363-
" atomic.update operation";
1364-
}
1365-
return parser.resolveOperand(expr, exprType, result.operands);
1340+
return success();
13661341
}
13671342

13681343
/// Printer for AtomicUpdateOp
13691344
static void printAtomicUpdateOp(OpAsmPrinter &p, AtomicUpdateOp op) {
1370-
p << " " << op.x() << " = ";
1371-
Value y, z;
1372-
if (op.isXBinopExpr()) {
1373-
y = op.x();
1374-
z = op.expr();
1375-
} else {
1376-
y = op.expr();
1377-
z = op.x();
1378-
}
1379-
p << y << " " << AtomicBinOpKindToString(op.binop()).lower() << " " << z
1380-
<< " ";
1345+
p << " ";
13811346
if (op.memory_order())
13821347
p << "memory_order(" << op.memory_order() << ") ";
13831348
if (op.hintAttr())
13841349
printSynchronizationHint(p, op, op.hintAttr());
1385-
p << ": " << op.x().getType() << ", " << op.expr().getType();
1350+
p << op.x() << " : " << op.x().getType();
1351+
p.printRegion(op.region());
13861352
}
13871353

13881354
/// Verifier for AtomicUpdateOp
@@ -1393,6 +1359,84 @@ static LogicalResult verifyAtomicUpdateOp(AtomicUpdateOp op) {
13931359
return op.emitError(
13941360
"memory-order must not be acq_rel or acquire for atomic updates");
13951361
}
1362+
if (op.region().getNumArguments() != 1)
1363+
return op.emitError("the region must accept exactly one argument");
1364+
1365+
if (op.x().getType().cast<PointerLikeType>().getElementType() !=
1366+
op.region().getArgument(0).getType()) {
1367+
return op.emitError(
1368+
"the type of the operand must be a pointer type whose "
1369+
"element type is the same as that of the region argument");
1370+
}
1371+
1372+
YieldOp yieldOp = *op.region().getOps<YieldOp>().begin();
1373+
if (yieldOp.results().size() != 1)
1374+
return op.emitError("only updated value must be returned");
1375+
if (yieldOp.results().front().getType() !=
1376+
op.region().getArgument(0).getType())
1377+
return op.emitError("input and yielded value must have the same type");
1378+
return success();
1379+
}
1380+
1381+
//===----------------------------------------------------------------------===//
1382+
// AtomicCaptureOp
1383+
//===----------------------------------------------------------------------===//
1384+
1385+
/// Parser for AtomicCaptureOp
1386+
static LogicalResult parseAtomicCaptureOp(OpAsmParser &parser,
1387+
OperationState &result) {
1388+
SmallVector<ClauseType> clauses = {memoryOrderClause, hintClause};
1389+
SmallVector<int> segments;
1390+
if (parseClauses(parser, result, clauses, segments) ||
1391+
parser.parseRegion(*result.addRegion()))
1392+
return failure();
1393+
return success();
1394+
}
1395+
1396+
/// Printer for AtomicCaptureOp
1397+
static void printAtomicCaptureOp(OpAsmPrinter &p, AtomicCaptureOp op) {
1398+
if (op.memory_order())
1399+
p << "memory_order(" << op.memory_order() << ") ";
1400+
if (op.hintAttr())
1401+
printSynchronizationHint(p, op, op.hintAttr());
1402+
p.printRegion(op.region());
1403+
}
1404+
1405+
/// Verifier for AtomicCaptureOp
1406+
static LogicalResult verifyAtomicCaptureOp(AtomicCaptureOp op) {
1407+
Block::OpListType &ops = op.region().front().getOperations();
1408+
if (ops.size() != 3)
1409+
return emitError(op.getLoc())
1410+
<< "expected three operations in omp.atomic.capture region (one"
1411+
" terminator, and two atomic ops";
1412+
auto &firstOp = ops.front();
1413+
auto &secondOp = *ops.getNextNode(firstOp);
1414+
auto firstReadStmt = dyn_cast<AtomicReadOp>(firstOp);
1415+
auto firstUpdateStmt = dyn_cast<AtomicUpdateOp>(firstOp);
1416+
auto secondReadStmt = dyn_cast<AtomicReadOp>(secondOp);
1417+
auto secondUpdateStmt = dyn_cast<AtomicUpdateOp>(secondOp);
1418+
auto secondWriteStmt = dyn_cast<AtomicWriteOp>(secondOp);
1419+
1420+
if (!((firstUpdateStmt && secondReadStmt) ||
1421+
(firstReadStmt && secondUpdateStmt) ||
1422+
(firstReadStmt && secondWriteStmt)))
1423+
return emitError(ops.front().getLoc())
1424+
<< "invalid sequence of operations in the capture region";
1425+
if (firstUpdateStmt && secondReadStmt &&
1426+
firstUpdateStmt.x() != secondReadStmt.x())
1427+
return emitError(firstUpdateStmt.getLoc())
1428+
<< "updated variable in omp.atomic.update must be captured in "
1429+
"second operation";
1430+
if (firstReadStmt && secondUpdateStmt &&
1431+
firstReadStmt.x() != secondUpdateStmt.x())
1432+
return emitError(firstReadStmt.getLoc())
1433+
<< "captured variable in omp.atomic.read must be updated in "
1434+
"second operation";
1435+
if (firstReadStmt && secondWriteStmt &&
1436+
firstReadStmt.x() != secondWriteStmt.address())
1437+
return emitError(firstReadStmt.getLoc())
1438+
<< "captured variable in omp.atomic.read must be updated in "
1439+
"second operation";
13961440
return success();
13971441
}
13981442

0 commit comments

Comments
 (0)