Skip to content

Commit 105bcec

Browse files
committed
address reviewer comments and rebase
1 parent 5b1a2ad commit 105bcec

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,15 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
6161
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
6262
return ptrLike;
6363
Value md = fromPtr.getMetadata();
64-
// If there's no metadata in the op fold the op.
65-
if (!md)
66-
ptrLike = toPtr.getPtr();
67-
// Fold if the metadata can be verified to be equal.
68-
else if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
69-
mdOp && mdOp.getPtr() == toPtr.getPtr())
64+
// If the type has trivial metadata fold.
65+
if (!fromPtr.getType().hasPtrMetadata()) {
7066
ptrLike = toPtr.getPtr();
67+
} else if (md) {
68+
// Fold if the metadata can be verified to be equal.
69+
if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
70+
mdOp && mdOp.getPtr() == toPtr.getPtr())
71+
ptrLike = toPtr.getPtr();
72+
}
7173
// Check for a sequence of casts.
7274
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
7375
: nullptr);

mlir/test/Dialect/Ptr/canonicalize.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ func.func @test_from_ptr_0(%mr: memref<f32, #ptr.generic_space>) -> memref<f32,
2828
return %res : memref<f32, #ptr.generic_space>
2929
}
3030

31+
/// Check the op doesn't fold because folding a ptr-type with metadata requires knowing the origin of the metadata.
3132
// CHECK-LABEL: @test_from_ptr_1
3233
// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
3334
func.func @test_from_ptr_1(%mr: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
34-
// CHECK-NOT: ptr.to_ptr
35-
// CHECK-NOT: ptr.from_ptr
36-
// CHECK: return %[[MEM_REF]]
35+
// CHECK: ptr.to_ptr
36+
// CHECK: ptr.from_ptr
3737
%ptr = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
3838
%res = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
3939
return %res : memref<f32, #ptr.generic_space>

0 commit comments

Comments
 (0)