Skip to content

Commit a91c93d

Browse files
Remove symbolTable op arg and simplify tests.
1 parent b8f439e commit a91c93d

File tree

3 files changed

+16
-47
lines changed

3 files changed

+16
-47
lines changed

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def MemRefAllocaToGlobalOp :
153153
DeclareOpInterfaceMethods<TransformOpInterface>]> {
154154
let description = [{
155155
Inserts a new `memref.global` for each provided `memref.alloca` into the
156-
provided symbol table (e.g., a `builtin.module`) and replaces it with a
156+
nearest symbol table (e.g., a `builtin.module`) and replaces it with a
157157
`memref.get_global`. This is useful, for example, for allocations that
158158
should reside in the shared memory of a GPU, which have to be declared as
159159
globals.
@@ -164,8 +164,8 @@ def MemRefAllocaToGlobalOp :
164164

165165
```mlir
166166
%get_global, %global =
167-
transform.memref.alloca_to_global %alloca in %module
168-
: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
167+
transform.memref.alloca_to_global %alloca
168+
: (!transform.op<"memref.alloca">)
169169
-> (!transform.any_op, !transform.any_op)
170170
```
171171

@@ -195,20 +195,16 @@ def MemRefAllocaToGlobalOp :
195195

196196
#### Return modes
197197

198-
Emits a definite failure if not exactly one symbol table payload op was
199-
provided or any of the `alloca` payload ops is not inside that symbol table
200-
op, and succeeds otherwise. The returned handles refer to the
201-
`memref.get_global` and `memref.global` ops that were inserted by the
202-
transformation.
198+
Succeeds always. The returned handles refer to the `memref.get_global` and
199+
`memref.global` ops that were inserted by the transformation.
203200
}];
204201

205-
let arguments = (ins TransformHandleTypeInterface:$symbolTable,
206-
Transform_MemRefAllocaOp:$alloca);
202+
let arguments = (ins Transform_MemRefAllocaOp:$alloca);
207203
let results = (outs TransformHandleTypeInterface:$getGlobal,
208204
TransformHandleTypeInterface:$global);
209205

210206
let assemblyFormat = [{
211-
$alloca `in` $symbolTable attr-dict `:` functional-type(operands, results)
207+
$alloca attr-dict `:` functional-type(operands, results)
212208
}];
213209
}
214210

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -139,26 +139,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
139139
SmallVector<memref::GlobalOp> globalOps;
140140
SmallVector<memref::GetGlobalOp> getGlobalOps;
141141

142-
// Get containing symbol table op.
143-
auto symbolTableOps = state.getPayloadOps(getSymbolTable());
144-
if (!llvm::hasSingleElement(symbolTableOps)) {
145-
return emitDefiniteFailure()
146-
<< Twine("expected exactly one 'symbolTable' payload, but found ") +
147-
std::to_string(llvm::range_size(symbolTableOps));
148-
}
149-
Operation *symbolTableOp = *symbolTableOps.begin();
150-
if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
151-
return emitDefiniteFailure() << Twine(
152-
"expected 'symbolTable' payload to have 'SymbolTable' trait");
153-
}
154-
SymbolTable symbolTable(symbolTableOp);
155-
156-
{
157-
size_t numAllocaOps = llvm::range_size(allocaOps);
158-
globalOps.reserve(numAllocaOps);
159-
getGlobalOps.reserve(numAllocaOps);
160-
}
161-
162142
// Transform `memref.alloca`s.
163143
for (auto *op : allocaOps) {
164144
auto alloca = cast<memref::AllocaOp>(op);
@@ -167,14 +147,15 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
167147

168148
memref::GlobalOp globalOp;
169149
{
150+
// Find nearest symbol table.
151+
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
152+
assert(symbolTableOp && "expected alloca payload to be in symbol table");
153+
SymbolTable symbolTable(symbolTableOp);
154+
170155
// Insert a `memref.global` into the symbol table.
171-
if (symbolTable.getOp() != SymbolTable::getNearestSymbolTable(op)) {
172-
return emitDefiniteFailure() << "expected 'alloca' payload to be "
173-
"inside 'symbolTable' payload";
174-
}
175156
Type resultType = alloca.getResult().getType();
176-
// TODO: Add a better builder for this.
177157
OpBuilder builder(rewriter.getContext());
158+
// TODO: Add a better builder for this.
178159
globalOp = builder.create<memref::GlobalOp>(
179160
loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
180161
TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
@@ -200,7 +181,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
200181

201182
void transform::MemRefAllocaToGlobalOp::getEffects(
202183
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
203-
onlyReadsHandle(getSymbolTable(), effects);
204184
producesHandle(getGlobal(), effects);
205185
producesHandle(getGetGlobal(), effects);
206186
consumesHandle(getAlloca(), effects);

mlir/test/Dialect/MemRef/transform-ops.mlir

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,9 @@ func.func @func(%lb: index, %ub: index) {
2323
transform.sequence failures(propagate) {
2424
^bb1(%arg0: !transform.any_op):
2525
%alloca = transform.structured.match ops{["memref.alloca"]} in %arg0
26-
: (!transform.any_op) -> !transform.any_op
27-
%module = transform.structured.match ops{["builtin.module"]} in %arg0
28-
: (!transform.any_op) -> !transform.any_op
29-
%alloca_typed = transform.cast %alloca
30-
: !transform.any_op to !transform.op<"memref.alloca">
31-
%module_typed = transform.cast %module
32-
: !transform.any_op to !transform.op<"builtin.module">
33-
%get_global, %global =
34-
transform.memref.alloca_to_global %alloca_typed in %module_typed
35-
: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">)
26+
: (!transform.any_op) -> !transform.op<"memref.alloca">
27+
%get_global, %global = transform.memref.alloca_to_global %alloca
28+
: (!transform.op<"memref.alloca">)
3629
-> (!transform.any_op, !transform.any_op)
3730
}
3831

0 commit comments

Comments
 (0)