Skip to content

Commit 5668074

Browse files
committed
make folders work on cast sequences
1 parent 9c5c7b0 commit 5668074

File tree

1 file changed

+34
-18
lines changed

1 file changed

+34
-18
lines changed

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,28 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
5252
// %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
5353
// To:
5454
// %val -> %v
55-
auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr().getDefiningOp());
56-
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
57-
// different.
58-
if (!toPtr || toPtr.getPtr().getType() != getType())
59-
return nullptr;
60-
Value md = getMetadata();
61-
if (!md)
62-
return toPtr.getPtr();
63-
// Fold if the metadata can be verified to be equal.
64-
if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
65-
mdOp && mdOp.getPtr() == toPtr.getPtr())
66-
return toPtr.getPtr();
67-
return nullptr;
55+
Value ptrLike;
56+
FromPtrOp fromPtr = *this;
57+
while (fromPtr != nullptr) {
58+
auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
59+
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
60+
// different.
61+
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
62+
return ptrLike;
63+
Value md = fromPtr.getMetadata();
64+
// If there's no metadata in the op, either the cast never requires metadata
65+
// or the op has the trivial metadata flag set, therefore fold.
66+
if (!md)
67+
ptrLike = toPtr.getPtr();
68+
// Fold if the metadata can be verified to be equal.
69+
else if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
70+
mdOp && mdOp.getPtr() == toPtr.getPtr())
71+
ptrLike = toPtr.getPtr();
72+
// Check for a sequence of casts.
73+
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
74+
: nullptr);
75+
}
76+
return ptrLike;
6877
}
6978

7079
LogicalResult FromPtrOp::verify() {
@@ -113,11 +122,18 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
113122
// %ptr = ptr.to_ptr %val : type -> ptr
114123
// To:
115124
// %ptr -> %p
116-
auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr().getDefiningOp());
117-
// Cannot fold if it's not a `from_ptr` op.
118-
if (!fromPtr)
119-
return nullptr;
120-
return fromPtr.getPtr();
125+
Value ptr;
126+
ToPtrOp toPtr = *this;
127+
while (toPtr != nullptr) {
128+
auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
129+
// Cannot fold if it's not a `from_ptr` op.
130+
if (!fromPtr)
131+
return ptr;
132+
ptr = fromPtr.getPtr();
133+
// Check for chains of casts.
134+
toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
135+
}
136+
return ptr;
121137
}
122138

123139
LogicalResult ToPtrOp::verify() {

0 commit comments

Comments
 (0)