Skip to content

Commit a474d19

Browse files
committed
Add attribute to MemRef/Vector memory access ops
1 parent 01f9dff commit a474d19

File tree

9 files changed

+352
-5
lines changed

9 files changed

+352
-5
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,45 @@ def LoadOp : MemRef_Op<"load",
12271227
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
12281228
[MemRead]>:$memref,
12291229
Variadic<Index>:$indices,
1230-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1230+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1231+
ConfinedAttr<OptionalAttr<I32Attr>,
1232+
[IntPositive]>:$alignment);
1233+
1234+
let builders = [
1235+
OpBuilder<(ins "Value":$memref,
1236+
"ValueRange":$indices,
1237+
CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
1238+
return build($_builder, $_state, memref, indices, false, alignment);
1239+
}]>,
1240+
OpBuilder<(ins "Value":$memref,
1241+
"ValueRange":$indices,
1242+
"bool":$nontemporal), [{
1243+
return build($_builder, $_state, memref, indices, nontemporal,
1244+
IntegerAttr());
1245+
}]>,
1246+
OpBuilder<(ins "Type":$resultType,
1247+
"Value":$memref,
1248+
"ValueRange":$indices,
1249+
CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
1250+
return build($_builder, $_state, resultType, memref, indices, false,
1251+
alignment);
1252+
}]>,
1253+
OpBuilder<(ins "Type":$resultType,
1254+
"Value":$memref,
1255+
"ValueRange":$indices,
1256+
"bool":$nontemporal), [{
1257+
return build($_builder, $_state, resultType, memref, indices, nontemporal,
1258+
IntegerAttr());
1259+
}]>,
1260+
OpBuilder<(ins "TypeRange":$resultTypes,
1261+
"Value":$memref,
1262+
"ValueRange":$indices,
1263+
CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
1264+
return build($_builder, $_state, resultTypes, memref, indices, false,
1265+
alignment);
1266+
}]>
1267+
];
1268+
12311269
let results = (outs AnyType:$result);
12321270

12331271
let extraClassDeclaration = [{
@@ -1924,13 +1962,30 @@ def MemRef_StoreOp : MemRef_Op<"store",
19241962
Arg<AnyMemRef, "the reference to store to",
19251963
[MemWrite]>:$memref,
19261964
Variadic<Index>:$indices,
1927-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1965+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1966+
ConfinedAttr<OptionalAttr<I32Attr>,
1967+
[IntPositive]>:$alignment);
19281968

19291969
let builders = [
1970+
OpBuilder<(ins "Value":$valueToStore,
1971+
"Value":$memref,
1972+
"ValueRange":$indices,
1973+
CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
1974+
return build($_builder, $_state, valueToStore, memref, indices, false,
1975+
alignment);
1976+
}]>,
1977+
OpBuilder<(ins "Value":$valueToStore,
1978+
"Value":$memref,
1979+
"ValueRange":$indices,
1980+
"bool":$nontemporal), [{
1981+
return build($_builder, $_state, valueToStore, memref, indices, nontemporal,
1982+
IntegerAttr());
1983+
}]>,
19301984
OpBuilder<(ins "Value":$valueToStore, "Value":$memref), [{
19311985
$_state.addOperands(valueToStore);
19321986
$_state.addOperands(memref);
1933-
}]>];
1987+
}]>
1988+
];
19341989

19351990
let extraClassDeclaration = [{
19361991
Value getValueToStore() { return getOperand(0); }

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,7 +1739,34 @@ def Vector_LoadOp : Vector_Op<"load"> {
17391739
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
17401740
[MemRead]>:$base,
17411741
Variadic<Index>:$indices,
1742-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1742+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1743+
ConfinedAttr<OptionalAttr<I32Attr>,
1744+
[IntPositive]>:$alignment);
1745+
1746+
let builders = [
1747+
OpBuilder<(ins "VectorType":$resultType,
1748+
"Value":$base,
1749+
"ValueRange":$indices,
1750+
CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
1751+
return build($_builder, $_state, resultType, base, indices, false,
1752+
alignment);
1753+
}]>,
1754+
OpBuilder<(ins "VectorType":$resultType,
1755+
"Value":$base,
1756+
"ValueRange":$indices,
1757+
"bool":$nontemporal), [{
1758+
return build($_builder, $_state, resultType, base, indices, nontemporal,
1759+
IntegerAttr());
1760+
}]>,
1761+
OpBuilder<(ins "TypeRange":$resultTypes,
1762+
"Value":$base,
1763+
"ValueRange":$indices,
1764+
CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
1765+
return build($_builder, $_state, resultTypes, base, indices, false,
1766+
alignment);
1767+
}]>
1768+
];
1769+
17431770
let results = (outs AnyVectorOfAnyRank:$result);
17441771

17451772
let extraClassDeclaration = [{
@@ -1825,9 +1852,28 @@ def Vector_StoreOp : Vector_Op<"store"> {
18251852
Arg<AnyMemRef, "the reference to store to",
18261853
[MemWrite]>:$base,
18271854
Variadic<Index>:$indices,
1828-
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal
1855+
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
1856+
ConfinedAttr<OptionalAttr<I32Attr>,
1857+
[IntPositive]>:$alignment
18291858
);
18301859

1860+
let builders = [
1861+
OpBuilder<(ins "Value":$valueToStore,
1862+
"Value":$base,
1863+
"ValueRange":$indices,
1864+
CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{
1865+
return build($_builder, $_state, valueToStore, base, indices, false,
1866+
alignment);
1867+
}]>,
1868+
OpBuilder<(ins "Value":$valueToStore,
1869+
"Value":$base,
1870+
"ValueRange":$indices,
1871+
"bool":$nontemporal), [{
1872+
return build($_builder, $_state, valueToStore, base, indices, nontemporal,
1873+
IntegerAttr());
1874+
}]>
1875+
];
1876+
18311877
let extraClassDeclaration = [{
18321878
MemRefType getMemRefType() {
18331879
return ::llvm::cast<MemRefType>(getBase().getType());
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @test_load_store_alignment
4+
// CHECK: memref.load {{.*}} {alignment = 16 : i32}
5+
// CHECK: memref.store {{.*}} {alignment = 16 : i32}
6+
func.func @test_load_store_alignment(%memref: memref<4xi32>) {
7+
%c0 = arith.constant 0 : index
8+
%val = memref.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
9+
memref.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>
10+
return
11+
}
12+
13+
// -----
14+
15+
func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
16+
// expected-error @+1 {{custom op 'memref.load' 'memref.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
17+
%val = memref.load %memref[%c0] { alignment = -1 } : memref<4xi32>
18+
return
19+
}
20+
21+
// -----
22+
23+
func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: memref<4xi32>) {
24+
// expected-error @+1 {{custom op 'memref.store' 'memref.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
25+
memref.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>
26+
return
27+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @test_load_store_alignment
4+
// CHECK: vector.load {{.*}} {alignment = 16 : i32}
5+
// CHECK: vector.store {{.*}} {alignment = 16 : i32}
6+
func.func @test_load_store_alignment(%memref: memref<4xi32>) {
7+
%c0 = arith.constant 0 : index
8+
%val = vector.load %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
9+
vector.store %val, %memref[%c0] { alignment = 16 : i32 } : memref<4xi32>, vector<4xi32>
10+
return
11+
}
12+
13+
// -----
14+
15+
func.func @test_invalid_load_alignment(%memref: memref<4xi32>) {
16+
// expected-error @+1 {{custom op 'vector.load' 'vector.load' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
17+
%val = vector.load %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
18+
return
19+
}
20+
21+
// -----
22+
23+
func.func @test_invalid_store_alignment(%memref: memref<4xi32>, %val: vector<4xi32>) {
24+
// expected-error @+1 {{custom op 'vector.store' 'vector.store' op attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}}
25+
vector.store %val, %memref[%c0] { alignment = -1 } : memref<4xi32>, vector<4xi32>
26+
return
27+
}

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ add_subdirectory(SPIRV)
1818
add_subdirectory(SMT)
1919
add_subdirectory(Transform)
2020
add_subdirectory(Utils)
21+
add_subdirectory(Vector)

mlir/unittests/Dialect/MemRef/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_unittest(MLIRMemRefTests
22
InferShapeTest.cpp
3+
LoadStoreAlignment.cpp
34
)
45
mlir_target_link_libraries(MLIRMemRefTests
56
PRIVATE
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
//===- LoadStoreAlignment.cpp - unit tests for load/store alignment -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
10+
#include "mlir/IR/Builders.h"
11+
#include "mlir/IR/Verifier.h"
12+
#include "gtest/gtest.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::memref;
16+
17+
TEST(LoadStoreAlignmentTest, ValidAlignment) {
18+
MLIRContext ctx;
19+
OpBuilder b(&ctx);
20+
ctx.loadDialect<memref::MemRefDialect>();
21+
22+
// Create a dummy memref
23+
Type elementType = b.getI32Type();
24+
auto memrefType = MemRefType::get({4}, elementType);
25+
Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
26+
27+
// Create load with valid alignment
28+
Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
29+
IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), 16);
30+
auto loadOp =
31+
b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
32+
33+
// Verify the attribute exists
34+
auto alignmentAttr = loadOp->getAttrOfType<IntegerAttr>("alignment");
35+
EXPECT_TRUE(alignmentAttr != nullptr);
36+
EXPECT_EQ(alignmentAttr.getInt(), 16);
37+
38+
// Create store with valid alignment
39+
auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
40+
ValueRange{zero}, alignment);
41+
42+
// Verify the attribute exists
43+
alignmentAttr = storeOp->getAttrOfType<IntegerAttr>("alignment");
44+
EXPECT_TRUE(alignmentAttr != nullptr);
45+
EXPECT_EQ(alignmentAttr.getInt(), 16);
46+
}
47+
48+
TEST(LoadStoreAlignmentTest, InvalidAlignmentFailsVerification) {
49+
MLIRContext ctx;
50+
OpBuilder b(&ctx);
51+
ctx.loadDialect<memref::MemRefDialect>();
52+
53+
Type elementType = b.getI32Type();
54+
auto memrefType = MemRefType::get({4}, elementType);
55+
Value memref = b.create<memref::AllocaOp>(b.getUnknownLoc(), memrefType);
56+
57+
Value zero = b.create<arith::ConstantIndexOp>(b.getUnknownLoc(), 0);
58+
IntegerAttr alignment = IntegerAttr::get(IntegerType::get(&ctx, 32), -1);
59+
60+
auto loadOp =
61+
b.create<LoadOp>(b.getUnknownLoc(), memref, ValueRange{zero}, alignment);
62+
63+
// Capture diagnostics
64+
std::string errorMessage;
65+
ScopedDiagnosticHandler handler(
66+
&ctx, [&](Diagnostic &diag) { errorMessage = diag.str(); });
67+
68+
// Trigger verification
69+
auto result = mlir::verify(loadOp);
70+
71+
// Check results
72+
EXPECT_TRUE(failed(result));
73+
EXPECT_EQ(
74+
errorMessage,
75+
"'memref.load' op attribute 'alignment' failed to satisfy constraint: "
76+
"32-bit signless integer attribute whose value is positive");
77+
78+
auto storeOp = b.create<StoreOp>(b.getUnknownLoc(), loadOp, memref,
79+
ValueRange{zero}, alignment);
80+
result = mlir::verify(storeOp);
81+
82+
// Check results
83+
EXPECT_TRUE(failed(result));
84+
EXPECT_EQ(
85+
errorMessage,
86+
"'memref.store' op attribute 'alignment' failed to satisfy constraint: "
87+
"32-bit signless integer attribute whose value is positive");
88+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
add_mlir_unittest(MLIRVectorTests
2+
LoadStoreAlignment.cpp
3+
)
4+
mlir_target_link_libraries(MLIRVectorTests
5+
PRIVATE
6+
MLIRVectorDialect
7+
)

0 commit comments

Comments
 (0)