Skip to content

[mlir] Walk nested non-symbol table ops in symbol dce #143353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jpienaar
Copy link
Member

@jpienaar jpienaar commented Jun 9, 2025

The previous positioning was effectively that a symbol is dead if it cannot be addressed from top level. I think that is too strong a requirement: one can have operations that one cannot delete/DCE that refers to symbols which one could delete. This resulted in symbol-dce deleting symbols that are still referenced and the resulting IR being invalid.

This instead treats all the symbols of top level operations of non-symbol table ops additionally, as those are either dead and DCE would have handled, or alive and we cannot just delete symbols referenced internally. E.g., this treats non-symbol table regioned ops more conservatively.

The previous positioning was effectively that a symbol is dead if it
cannot be addressed from top level. I think that is too strong a
requirement: one can have operations that one cannot delete/DCE that
refers to symbols which one could delete. This resulted in symbol-dce
deleting symbols that are still referenced and the resulting IR being
invalid.

This instead all the symbols of top level operations of non-symbol table
ops additionally, as those are either dead and DCE would have handled,
or alive and we cannot just delete symbols referenced internally. E.g.,
this treats non-symbol table regioned ops more conservatively.
@jpienaar jpienaar requested review from ftynse and joker-eph June 9, 2025 07:54
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 9, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 9, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Jacques Pienaar (jpienaar)

Changes

The previous positioning was effectively that a symbol is dead if it cannot be addressed from top level. I think that is too strong a requirement: one can have operations that one cannot delete/DCE that refers to symbols which one could delete. This resulted in symbol-dce deleting symbols that are still referenced and the resulting IR being invalid.

This instead treats all the symbols of top level operations of non-symbol table ops additionally, as those are either dead and DCE would have handled, or alive and we cannot just delete symbols referenced internally. E.g., this treats non-symbol table regioned ops more conservatively.


Full diff: https://github.com/llvm/llvm-project/pull/143353.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/SymbolDCE.cpp (+47-2)
  • (modified) mlir/test/Transforms/test-symbol-dce.mlir (+19)
diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp
index 93d9a6547883a..52b4d06c98e32 100644
--- a/mlir/lib/Transforms/SymbolDCE.cpp
+++ b/mlir/lib/Transforms/SymbolDCE.cpp
@@ -22,6 +22,8 @@ namespace mlir {
 
 using namespace mlir;
 
+#define DEBUG_TYPE "symbol-dce"
+
 namespace {
 struct SymbolDCE : public impl::SymbolDCEBase<SymbolDCE> {
   void runOnOperation() override;
@@ -84,6 +86,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
                                          SymbolTableCollection &symbolTable,
                                          bool symbolTableIsHidden,
                                          DenseSet<Operation *> &liveSymbols) {
+  LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName()
+                          << "\n");
   // A worklist of live operations to propagate uses from.
   SmallVector<Operation *, 16> worklist;
 
@@ -108,6 +112,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
   // that are referenced within.
   while (!worklist.empty()) {
     Operation *op = worklist.pop_back_val();
+    LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n");
 
     // If this is a symbol table, recursively compute its liveness.
     if (op->hasTrait<OpTrait::SymbolTable>()) {
@@ -115,8 +120,34 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
       // symbol, or if it is a private symbol.
       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
       bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
+      LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName()
+                              << " is hidden: " << symIsHidden << "\n");
       if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
         return failure();
+    } else {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\tnon-symbol table: " << op->getName() << " is hidden\n");
+      // If the op is not a symbol table, then, unless op itself is dead which
+      // would be handled by DCE, we need to check all the regions and blocks
+      // within the op to find the uses (e.g., consider visibility within op as
+      // if top level rather than relying on pure symbol table visibility). This
+      // is more conservative than SymbolTable::walkSymbolTables in the case
+      // where there is again SymbolTable information to take advantage of.
+      for (auto &region : op->getRegions()) {
+        for (auto &block : region.getBlocks()) {
+          for (Operation &op : block) {
+            SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
+            if (!symbol) {
+              worklist.push_back(&op);
+              continue;
+            }
+            bool isDiscardable =
+                symbol.isPrivate() && symbol.canDiscardOnUseEmpty();
+            if (!isDiscardable && liveSymbols.insert(&op).second)
+              worklist.push_back(&op);
+          }
+        }
+      }
     }
 
     // Collect the uses held by this operation.
@@ -128,13 +159,27 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
     }
 
     SmallVector<Operation *, 4> resolvedSymbols;
+    // Get the first parent symbol table op.
+    Operation *parentOp = op->getParentOp();
+    while (parentOp && !parentOp->hasTrait<OpTrait::SymbolTable>()) {
+      parentOp = parentOp->getParentOp();
+    }
+    assert(parentOp && "operation has no parent symbol table");
+
+    LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n");
     for (const SymbolTable::SymbolUse &use : *uses) {
+      LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n");
       // Lookup the symbols referenced by this use.
       resolvedSymbols.clear();
-      if (failed(symbolTable.lookupSymbolIn(
-              op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
+      if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(),
+                                            resolvedSymbols)))
         // Ignore references to unknown symbols.
         continue;
+      LLVM_DEBUG({
+        llvm::dbgs() << "\t\tresolved symbols: ";
+        llvm::interleaveComma(resolvedSymbols, llvm::dbgs());
+        llvm::dbgs() << "\n";
+      });
 
       // Mark each of the resolved symbols as live.
       for (Operation *resolvedSymbol : resolvedSymbols)
diff --git a/mlir/test/Transforms/test-symbol-dce.mlir b/mlir/test/Transforms/test-symbol-dce.mlir
index 7bd784928e6f3..d44af1b93d241 100644
--- a/mlir/test/Transforms/test-symbol-dce.mlir
+++ b/mlir/test/Transforms/test-symbol-dce.mlir
@@ -98,3 +98,22 @@ module {
   // CHECK: "live.user"() {uses = [@unknown_symbol]} : () -> ()
   "live.user"() {uses = [@unknown_symbol]} : () -> ()
 }
+
+// -----
+
+// Check that we don't DCE nested symbols if they are used even if nested inside
+// an unnamed region.
+// CHECK-LABEL: module attributes {test.nested_unnamed_region}
+module attributes {test.nested_unnamed_region} {
+  "test.one_region_op"() ({
+    "test.symbol_scope"() ({
+      // CHECK: func @nested_function
+      func.func @nested_function() {
+        return
+      }
+      func.call @nested_function() : () -> ()
+      "test.finish"() : () -> ()
+    }) : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+}

@@ -22,6 +22,8 @@ namespace mlir {

using namespace mlir;

#define DEBUG_TYPE "symbol-dce"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IWYU nit: add #include "llvm/Support/Debug.h"

for (auto &region : op->getRegions()) {
for (auto &block : region.getBlocks()) {
for (Operation &op : block) {
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm mildly confused, but this is mostly for my education, not a request to change anything: isn't a symbol operation defined by the property that it "resides immediately within a region that defines a SymbolTable"? [1] If that is the case, it's not clear to me why this cast would ever trigger (the else block this path lies within implies the parent operation does not define a symbol table)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(the else block this path lies within implies the parent operation does not define a symbol table)

Where do we check the parent?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 118

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I got confused by the shadowing or using op in the loop line 138...
That's not great: we should use another name here.

for (Operation &op : block) {
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
if (!symbol) {
worklist.push_back(&op);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That breaks the description of the loop I think?

   // Process the set of symbols that were known to be live, adding new symbols
   // that are referenced within.

Operation *parentOp = op->getParentOp();
while (parentOp && !parentOp->hasTrait<OpTrait::SymbolTable>()) {
parentOp = parentOp->getParentOp();
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused about this code-path.
Shouldn't we expect a symbol here? And so shouldn't the immediate parent be a symbol-table?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants