@@ -52,19 +52,28 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
52
52
// %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
53
53
// To:
54
54
// %val -> %v
55
- auto toPtr = dyn_cast_or_null<ToPtrOp>(getPtr ().getDefiningOp ());
56
- // Cannot fold if it's not a `to_ptr` op or the initial and final types are
57
- // different.
58
- if (!toPtr || toPtr.getPtr ().getType () != getType ())
59
- return nullptr ;
60
- Value md = getMetadata ();
61
- if (!md)
62
- return toPtr.getPtr ();
63
- // Fold if the metadata can be verified to be equal.
64
- if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp ());
65
- mdOp && mdOp.getPtr () == toPtr.getPtr ())
66
- return toPtr.getPtr ();
67
- return nullptr ;
55
+ Value ptrLike;
56
+ FromPtrOp fromPtr = *this ;
57
+ while (fromPtr != nullptr ) {
58
+ auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr ().getDefiningOp ());
59
+ // Cannot fold if it's not a `to_ptr` op or the initial and final types are
60
+ // different.
61
+ if (!toPtr || toPtr.getPtr ().getType () != fromPtr.getType ())
62
+ return ptrLike;
63
+ Value md = fromPtr.getMetadata ();
64
+ // If there's no metadata in the op, either the cast never requires metadata
65
+ // or the op has the trivial metadata flag set, therefore fold.
66
+ if (!md)
67
+ ptrLike = toPtr.getPtr ();
68
+ // Fold if the metadata can be verified to be equal.
69
+ else if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp ());
70
+ mdOp && mdOp.getPtr () == toPtr.getPtr ())
71
+ ptrLike = toPtr.getPtr ();
72
+ // Check for a sequence of casts.
73
+ fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp ()
74
+ : nullptr );
75
+ }
76
+ return ptrLike;
68
77
}
69
78
70
79
LogicalResult FromPtrOp::verify () {
@@ -113,11 +122,18 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
113
122
// %ptr = ptr.to_ptr %val : type -> ptr
114
123
// To:
115
124
// %ptr -> %p
116
- auto fromPtr = dyn_cast_or_null<FromPtrOp>(getPtr ().getDefiningOp ());
117
- // Cannot fold if it's not a `from_ptr` op.
118
- if (!fromPtr)
119
- return nullptr ;
120
- return fromPtr.getPtr ();
125
+ Value ptr;
126
+ ToPtrOp toPtr = *this ;
127
+ while (toPtr != nullptr ) {
128
+ auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr ().getDefiningOp ());
129
+ // Cannot fold if it's not a `from_ptr` op.
130
+ if (!fromPtr)
131
+ return ptr;
132
+ ptr = fromPtr.getPtr ();
133
+ // Check for chains of casts.
134
+ toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp ());
135
+ }
136
+ return ptr;
121
137
}
122
138
123
139
LogicalResult ToPtrOp::verify () {
0 commit comments