@@ -139,26 +139,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
139
139
SmallVector<memref::GlobalOp> globalOps;
140
140
SmallVector<memref::GetGlobalOp> getGlobalOps;
141
141
142
- // Get containing symbol table op.
143
- auto symbolTableOps = state.getPayloadOps (getSymbolTable ());
144
- if (!llvm::hasSingleElement (symbolTableOps)) {
145
- return emitDefiniteFailure ()
146
- << Twine (" expected exactly one 'symbolTable' payload, but found " ) +
147
- std::to_string (llvm::range_size (symbolTableOps));
148
- }
149
- Operation *symbolTableOp = *symbolTableOps.begin ();
150
- if (!symbolTableOp->hasTrait <OpTrait::SymbolTable>()) {
151
- return emitDefiniteFailure () << Twine (
152
- " expected 'symbolTable' payload to have 'SymbolTable' trait" );
153
- }
154
- SymbolTable symbolTable (symbolTableOp);
155
-
156
- {
157
- size_t numAllocaOps = llvm::range_size (allocaOps);
158
- globalOps.reserve (numAllocaOps);
159
- getGlobalOps.reserve (numAllocaOps);
160
- }
161
-
162
142
// Transform `memref.alloca`s.
163
143
for (auto *op : allocaOps) {
164
144
auto alloca = cast<memref::AllocaOp>(op);
@@ -167,14 +147,15 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
167
147
168
148
memref::GlobalOp globalOp;
169
149
{
150
+ // Find nearest symbol table.
151
+ Operation *symbolTableOp = SymbolTable::getNearestSymbolTable (op);
152
+ assert (symbolTableOp && " expected alloca payload to be in symbol table" );
153
+ SymbolTable symbolTable (symbolTableOp);
154
+
170
155
// Insert a `memref.global` into the symbol table.
171
- if (symbolTable.getOp () != SymbolTable::getNearestSymbolTable (op)) {
172
- return emitDefiniteFailure () << " expected 'alloca' payload to be "
173
- " inside 'symbolTable' payload" ;
174
- }
175
156
Type resultType = alloca.getResult ().getType ();
176
- // TODO: Add a better builder for this.
177
157
OpBuilder builder (rewriter.getContext ());
158
+ // TODO: Add a better builder for this.
178
159
globalOp = builder.create <memref::GlobalOp>(
179
160
loc, StringAttr::get (ctx, " alloca" ), StringAttr::get (ctx, " private" ),
180
161
TypeAttr::get (resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
@@ -200,7 +181,6 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
200
181
201
182
void transform::MemRefAllocaToGlobalOp::getEffects (
202
183
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
203
- onlyReadsHandle (getSymbolTable (), effects);
204
184
producesHandle (getGlobal (), effects);
205
185
producesHandle (getGetGlobal (), effects);
206
186
consumesHandle (getAlloca (), effects);
0 commit comments