Skip to content

[mlir][core|ptr] Add PtrLikeTypeInterface and casting ops to the ptr dialect #137469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>

def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
PtrLikeTypeInterface,
VectorElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "getIndexBitwidth", "verifyEntries",
Expand All @@ -63,6 +64,55 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
return $_get(memorySpace.getContext(), memorySpace);
}]>
];
let extraClassDeclaration = [{
// `PtrLikeTypeInterface` interface methods.
/// Returns `Type()` as this pointer type is opaque.
Type getElementType() const {
return Type();
}
/// Clones the pointer with specified memory space or returns failure
/// if an `elementType` was specified or if the memory space doesn't
/// implement `MemorySpaceAttrInterface`.
FailureOr<PtrLikeTypeInterface> clonePtrWith(Attribute memorySpace,
std::optional<Type> elementType) const {
if (elementType)
return failure();
if (auto ms = dyn_cast<MemorySpaceAttrInterface>(memorySpace))
return cast<PtrLikeTypeInterface>(get(ms));
return failure();
}
/// `!ptr.ptr` types are seen as ptr-like objects with no metadata.
bool hasPtrMetadata() const {
return false;
}
}];
}

def Ptr_PtrMetadata : Ptr_Type<"PtrMetadata", "ptr_metadata"> {
let summary = "Pointer metadata type";
let description = [{
The `ptr_metadata` type represents an opaque-view of the metadata associated
with a `ptr-like` object type.

Note: It's a verification error to construct a `ptr_metadata` type using a
`ptr-like` type with no metadata.

Example:

```mlir
// The metadata associated with a `memref` type.
!ptr.ptr_metadata<memref<f32>>
```
}];
let parameters = (ins "PtrLikeTypeInterface":$type);
let assemblyFormat = "`<` $type `>`";
let builders = [
TypeBuilderWithInferredContext<(ins
"PtrLikeTypeInterface":$ptrLike), [{
return $_get(ptrLike.getContext(), ptrLike);
}]>
];
let genVerifyDecl = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
100 changes: 98 additions & 2 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,72 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/OpAsmInterface.td"

//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//

def Ptr_FromPtrOp : Pointer_Op<"from_ptr", [
Pure, OptionalTypesMatchWith<"metadata type", "result", "metadata",
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
]> {
let summary = "Casts a `!ptr.ptr` value to a ptr-like value.";
let description = [{
The `from_ptr` operation casts a `ptr` value to a ptr-like object. It's
important to note that:
- The ptr-like object cannot be a `!ptr.ptr`.
- The memory-space of both the `ptr` and ptr-like object must match.
- The cast is Pure (no UB and side-effect free).

The optional `metadata` operand exists to provide any ptr-like metadata
that might be required to perform the cast.

Example:

```mlir
%typed_ptr = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> !my.ptr<f32, #ptr.generic_space>
%memref = ptr.from_ptr %ptr metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>

// Cast the `%ptr` to a memref without utilizing metadata.
%memref = ptr.from_ptr %ptr : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
```
}];

let arguments = (ins Ptr_PtrType:$ptr, Optional<Ptr_PtrMetadata>:$metadata);
let results = (outs PtrLikeTypeInterface:$result);
let assemblyFormat = [{
$ptr (`metadata` $metadata^)? attr-dict `:` type($ptr) `->` type($result)
}];
let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// GetMetadataOp
//===----------------------------------------------------------------------===//

def Ptr_GetMetadataOp : Pointer_Op<"get_metadata", [
Pure, TypesMatchWith<"metadata type", "ptr", "result",
"PtrMetadataType::get(cast<PtrLikeTypeInterface>($_self))">
]> {
let summary = "SSA value representing pointer metadata.";
let description = [{
The `get_metadata` operation produces an opaque value that encodes the
metadata of the ptr-like type.

Example:

```mlir
%metadata = ptr.get_metadata %memref : memref<?x?xf32>
```
}];

let arguments = (ins PtrLikeTypeInterface:$ptr);
let results = (outs Ptr_PtrMetadata:$result);
let assemblyFormat = [{
$ptr attr-dict `:` type($ptr)
}];
}

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
Expand All @@ -32,8 +98,8 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
Example:

```mlir
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<0>, i32
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<0>, i32
%x_off = ptr.ptr_add %x, %off : !ptr.ptr<#ptr.generic_space>, i32
%x_off0 = ptr.ptr_add nusw %x, %off : !ptr.ptr<#ptr.generic_space>, i32
```
}];

Expand All @@ -52,6 +118,36 @@ def Ptr_PtrAddOp : Pointer_Op<"ptr_add", [
}];
}

//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//

def Ptr_ToPtrOp : Pointer_Op<"to_ptr", [Pure]> {
let summary = "Casts a ptr-like value to a `!ptr.ptr` value.";
let description = [{
The `to_ptr` operation casts a ptr-like object to a `!ptr.ptr`. It's
important to note that:
- The ptr-like object cannot be a `!ptr.ptr`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but I wonder why this restriction is needed. We could allow it and have the op fold away. This may fall out of type conversion during progressive lowering and will be annoying to handle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is that any !ptr.ptr to !ptr.ptr conversions must happen through other operations, for example, memory space cast. This allows to put a barrier around acceptable casting semantics inside the dialect. The boundary of this barrier is those ops that allow to go to external ptr-like types.

- The memory-space of both the `ptr` and ptr-like object must match.
- The cast is side-effect free.

Example:

```mlir
%ptr0 = ptr.to_ptr %my_ptr : !my.ptr<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
%ptr1 = ptr.to_ptr %memref : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
```
}];

let arguments = (ins PtrLikeTypeInterface:$ptr);
let results = (outs Ptr_PtrType:$result);
let assemblyFormat = [{
$ptr attr-dict `:` type($ptr) `->` type($result)
}];
let hasFolder = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
Expand Down
53 changes: 53 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,59 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
}];
}

//===----------------------------------------------------------------------===//
// PtrLikeTypeInterface
//===----------------------------------------------------------------------===//

def PtrLikeTypeInterface : TypeInterface<"PtrLikeTypeInterface"> {
let cppNamespace = "::mlir";
let description = [{
A ptr-like type represents an object storing a memory address. This object
is constituted by:
- A memory address called the base pointer. This pointer is treated as a
bag of bits without any assumed structure. The bit-width of the base
pointer must be a compile-time constant. However, the bit-width may remain
opaque or unavailable during transformations that do not depend on the
base pointer. Finally, it is considered indivisible in the sense that as
a `PtrLikeTypeInterface` value, it has no metadata.
- Optional metadata about the pointer. For example, the size of the memory
region associated with the pointer.

Furthermore, all ptr-like types have two properties:
- The memory space associated with the address held by the pointer.
- An optional element type. If the element type is not specified, the
pointer is considered opaque.
}];
let methods = [
InterfaceMethod<[{
Returns the memory space of this ptr-like type.
}],
"::mlir::Attribute", "getMemorySpace">,
InterfaceMethod<[{
Returns the element type of this ptr-like type. Note: this method can
return `::mlir::Type()`, in which case the pointer is considered opaque.
}],
"::mlir::Type", "getElementType">,
InterfaceMethod<[{
Returns whether this ptr-like type has non-empty metadata.
}],
"bool", "hasPtrMetadata">,
Comment on lines +146 to +149
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit for bikeshedding, not blocking for PR: are we afraid of empty metadata because that would be equivalent to having a value of void type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's one of the main reasons. Also, from a impl point of view it allows having the !ptr.ptr_metadata verifier.

InterfaceMethod<[{
Returns a clone of this type with the given memory space and element type,
or `failure` if the type cannot be cloned with the specified arguments.
If the pointer is opaque and `elementType` is not `std::nullopt` the
method will return `failure`.

If no `elementType` is provided and ptr is not opaque, the `elementType`
of this type is used.
}],
"::llvm::FailureOr<::mlir::PtrLikeTypeInterface>", "clonePtrWith", (ins
"::mlir::Attribute":$memorySpace,
"::std::optional<::mlir::Type>":$elementType
)>
];
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
Expand Down
18 changes: 17 additions & 1 deletion mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class TensorType : public Type, public ShapedType::Trait<TensorType> {
/// Note: This class attaches the ShapedType trait to act as a mixin to
/// provide many useful utility functions. This inheritance has no effect
/// on derived memref types.
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
class BaseMemRefType : public Type,
public PtrLikeTypeInterface::Trait<BaseMemRefType>,
public ShapedType::Trait<BaseMemRefType> {
public:
using Type::Type;

Expand All @@ -117,6 +119,12 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;

/// Clone this type with the given memory space and element type. If the
/// provided element type is `std::nullopt`, the current element type of the
/// type is used.
FailureOr<PtrLikeTypeInterface>
clonePtrWith(Attribute memorySpace, std::optional<Type> elementType) const;

// Make sure that base class overloads are visible.
using ShapedType::Trait<BaseMemRefType>::clone;

Expand All @@ -141,8 +149,16 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;

/// Returns that this ptr-like object has non-empty ptr metadata.
bool hasPtrMetadata() const { return true; }

/// Allow implicit conversion to ShapedType.
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }

/// Allow implicit conversion to PtrLikeTypeInterface.
operator PtrLikeTypeInterface() const {
return llvm::cast<PtrLikeTypeInterface>(*this);
}
};

} // namespace mlir
Expand Down
2 changes: 2 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
//===----------------------------------------------------------------------===//

def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
Expand Down Expand Up @@ -1143,6 +1144,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
//===----------------------------------------------------------------------===//

def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
PtrLikeTypeInterface,
ShapedTypeInterface
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
Expand Down
80 changes: 80 additions & 0 deletions mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,52 @@ void PtrDialect::initialize() {
>();
}

//===----------------------------------------------------------------------===//
// FromPtrOp
//===----------------------------------------------------------------------===//

OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) {
// Fold the pattern:
// %ptr = ptr.to_ptr %v : type -> ptr
// (%mda = ptr.get_metadata %v : type)?
// %val = ptr.from_ptr %ptr (metadata %mda)? : ptr -> type
// To:
// %val -> %v
Value ptrLike;
FromPtrOp fromPtr = *this;
while (fromPtr != nullptr) {
auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp());
// Cannot fold if it's not a `to_ptr` op or the initial and final types are
// different.
if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType())
return ptrLike;
Value md = fromPtr.getMetadata();
// If the type has trivial metadata fold.
if (!fromPtr.getType().hasPtrMetadata()) {
ptrLike = toPtr.getPtr();
} else if (md) {
// Fold if the metadata can be verified to be equal.
if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp());
mdOp && mdOp.getPtr() == toPtr.getPtr())
ptrLike = toPtr.getPtr();
}
// Check for a sequence of casts.
fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp()
: nullptr);
}
return ptrLike;
}

LogicalResult FromPtrOp::verify() {
if (isa<PtrType>(getType()))
return emitError() << "the result type cannot be `!ptr.ptr`";
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
return emitError()
<< "expected the input and output to have the same memory space";
}
return success();
}

//===----------------------------------------------------------------------===//
// PtrAddOp
//===----------------------------------------------------------------------===//
Expand All @@ -55,6 +101,40 @@ OpFoldResult PtrAddOp::fold(FoldAdaptor adaptor) {
return nullptr;
}

//===----------------------------------------------------------------------===//
// ToPtrOp
//===----------------------------------------------------------------------===//

OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) {
// Fold the pattern:
// %val = ptr.from_ptr %p (metadata ...)? : ptr -> type
// %ptr = ptr.to_ptr %val : type -> ptr
// To:
// %ptr -> %p
Value ptr;
ToPtrOp toPtr = *this;
while (toPtr != nullptr) {
auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp());
// Cannot fold if it's not a `from_ptr` op.
if (!fromPtr)
return ptr;
ptr = fromPtr.getPtr();
// Check for chains of casts.
toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp());
}
return ptr;
}

LogicalResult ToPtrOp::verify() {
if (isa<PtrType>(getPtr().getType()))
return emitError() << "the input value cannot be of type `!ptr.ptr`";
if (getType().getMemorySpace() != getPtr().getType().getMemorySpace()) {
return emitError()
<< "expected the input and output to have the same memory space";
}
return success();
}

//===----------------------------------------------------------------------===//
// TypeOffsetOp
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading