Skip to content

Commit 9744d39

Browse files
author
Jeff Niu
authored
[mlir][index] Implement folders for CastSOp and CastUOp (#66960)
Fixes #66402
1 parent b967f3a commit 9744d39

File tree

6 files changed

+191
-2
lines changed

6 files changed

+191
-2
lines changed

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def Index_XOrOp : IndexBinaryOp<"xor", [Commutative, Pure]> {
446446
// CastSOp
447447
//===----------------------------------------------------------------------===//
448448

449-
def Index_CastSOp : IndexOp<"casts", [Pure,
449+
def Index_CastSOp : IndexOp<"casts", [Pure,
450450
DeclareOpInterfaceMethods<CastOpInterface>]> {
451451
let summary = "index signed cast";
452452
let description = [{
@@ -469,13 +469,14 @@ def Index_CastSOp : IndexOp<"casts", [Pure,
469469
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
470470
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
471471
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
472+
let hasFolder = 1;
472473
}
473474

474475
//===----------------------------------------------------------------------===//
475476
// CastUOp
476477
//===----------------------------------------------------------------------===//
477478

478-
def Index_CastUOp : IndexOp<"castu", [Pure,
479+
def Index_CastUOp : IndexOp<"castu", [Pure,
479480
DeclareOpInterfaceMethods<CastOpInterface>]> {
480481
let summary = "index unsigned cast";
481482
let description = [{
@@ -498,6 +499,7 @@ def Index_CastUOp : IndexOp<"castu", [Pure,
498499
let arguments = (ins AnyTypeOf<[AnyInteger, Index]>:$input);
499500
let results = (outs AnyTypeOf<[AnyInteger, Index]>:$output);
500501
let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
502+
let hasFolder = 1;
501503
}
502504

503505
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,11 +444,63 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
444444
// CastSOp
445445
//===----------------------------------------------------------------------===//
446446

447+
static OpFoldResult
448+
foldCastOp(Attribute input, Type type,
449+
function_ref<APInt(const APInt &, unsigned)> extFn,
450+
function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
451+
auto attr = dyn_cast_if_present<IntegerAttr>(input);
452+
if (!attr)
453+
return {};
454+
const APInt &value = attr.getValue();
455+
456+
if (isa<IndexType>(type)) {
457+
// When casting to an index type, perform the cast assuming a 64-bit target.
458+
// The result can be truncated to 32 bits as needed and always be correct.
459+
// This is because `cast32(cast64(value)) == cast32(value)`.
460+
APInt result = extOrTruncFn(value, 64);
461+
return IntegerAttr::get(type, result);
462+
}
463+
464+
// When casting from an index type, we must ensure the results respect
465+
// `cast_t(value) == cast_t(trunc32(value))`.
466+
auto intType = cast<IntegerType>(type);
467+
unsigned width = intType.getWidth();
468+
469+
// If the result type is at most 32 bits, then the cast can always be folded
470+
// because it is always a truncation.
471+
if (width <= 32) {
472+
APInt result = value.trunc(width);
473+
return IntegerAttr::get(type, result);
474+
}
475+
476+
// If the result type is at least 64 bits, then the cast is always a
477+
// extension. The results will differ if `trunc32(value) != value)`.
478+
if (width >= 64) {
479+
if (extFn(value.trunc(32), 64) != value)
480+
return {};
481+
APInt result = extFn(value, width);
482+
return IntegerAttr::get(type, result);
483+
}
484+
485+
// Otherwise, we just have to check the property directly.
486+
APInt result = value.trunc(width);
487+
if (result != extFn(value.trunc(32), width))
488+
return {};
489+
return IntegerAttr::get(type, result);
490+
}
491+
447492
bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
448493
return llvm::isa<IndexType>(lhsTypes.front()) !=
449494
llvm::isa<IndexType>(rhsTypes.front());
450495
}
451496

497+
OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
498+
return foldCastOp(
499+
adaptor.getInput(), getType(),
500+
[](const APInt &x, unsigned width) { return x.sext(width); },
501+
[](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
502+
}
503+
452504
//===----------------------------------------------------------------------===//
453505
// CastUOp
454506
//===----------------------------------------------------------------------===//
@@ -458,6 +510,13 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
458510
llvm::isa<IndexType>(rhsTypes.front());
459511
}
460512

513+
OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
514+
return foldCastOp(
515+
adaptor.getInput(), getType(),
516+
[](const APInt &x, unsigned width) { return x.zext(width); },
517+
[](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
518+
}
519+
461520
//===----------------------------------------------------------------------===//
462521
// CmpOp
463522
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Index/index-canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,3 +556,19 @@ func.func @sub_identity(%arg0: index) -> index {
556556
// CHECK-NEXT: return %arg0
557557
return %0 : index
558558
}
559+
560+
// CHECK-LABEL: @castu_to_index
561+
func.func @castu_to_index() -> index {
562+
// CHECK: index.constant 8000000000000
563+
%0 = arith.constant 8000000000000 : i48
564+
%1 = index.castu %0 : i48 to index
565+
return %1 : index
566+
}
567+
568+
// CHECK-LABEL: @casts_to_index
569+
func.func @casts_to_index() -> index {
570+
// CHECK: index.constant -1000
571+
%0 = arith.constant -1000 : i48
572+
%1 = index.casts %0 : i48 to index
573+
return %1 : index
574+
}

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
66
MLIRIR
77
MLIRDialect)
88

9+
add_subdirectory(Index)
910
add_subdirectory(LLVMIR)
1011
add_subdirectory(MemRef)
1112
add_subdirectory(SparseTensor)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
add_mlir_unittest(MLIRIndexOpsTests
2+
IndexOpsFoldersTest.cpp
3+
)
4+
target_link_libraries(MLIRIndexOpsTests
5+
PRIVATE
6+
MLIRIndexDialect
7+
)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//===- IndexOpsFoldersTest.cpp - unit tests for index op folders ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Index/IR/IndexDialect.h"
10+
#include "mlir/Dialect/Index/IR/IndexOps.h"
11+
#include "mlir/IR/OwningOpRef.h"
12+
#include "gtest/gtest.h"
13+
14+
using namespace mlir;
15+
16+
namespace {
17+
/// Test fixture for testing operation folders.
18+
class IndexFolderTest : public testing::Test {
19+
public:
20+
IndexFolderTest() { ctx.getOrLoadDialect<index::IndexDialect>(); }
21+
22+
/// Instantiate an operation, invoke its folder, and return the attribute
23+
/// result.
24+
template <typename OpT>
25+
void foldOp(IntegerAttr &value, Type type, ArrayRef<Attribute> operands);
26+
27+
protected:
28+
/// The MLIR context to use.
29+
MLIRContext ctx;
30+
/// A builder to use.
31+
OpBuilder b{&ctx};
32+
};
33+
} // namespace
34+
35+
template <typename OpT>
36+
void IndexFolderTest::foldOp(IntegerAttr &value, Type type,
37+
ArrayRef<Attribute> operands) {
38+
// This function returns null so that `ASSERT_*` works within it.
39+
OperationState state(UnknownLoc::get(&ctx), OpT::getOperationName());
40+
state.addTypes(type);
41+
OwningOpRef<OpT> op = cast<OpT>(b.create(state));
42+
SmallVector<OpFoldResult> results;
43+
LogicalResult result = op->getOperation()->fold(operands, results);
44+
// Propagate the failure to the test.
45+
if (failed(result)) {
46+
value = nullptr;
47+
return;
48+
}
49+
ASSERT_EQ(results.size(), 1u);
50+
value = dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(results.front()));
51+
ASSERT_TRUE(value);
52+
}
53+
54+
TEST_F(IndexFolderTest, TestCastUOpFolder) {
55+
IntegerAttr value;
56+
auto fold = [&](Type type, Attribute input) {
57+
foldOp<index::CastUOp>(value, type, input);
58+
};
59+
60+
// Target width less than or equal to 32 bits.
61+
fold(b.getIntegerType(16), b.getIndexAttr(8000000000));
62+
ASSERT_TRUE(value);
63+
EXPECT_EQ(value.getInt(), 20480u);
64+
65+
// Target width greater than or equal to 64 bits.
66+
fold(b.getIntegerType(64), b.getIndexAttr(2000));
67+
ASSERT_TRUE(value);
68+
EXPECT_EQ(value.getInt(), 2000u);
69+
70+
// Fails to fold, because truncating to 32 bits and then extending creates a
71+
// different value.
72+
fold(b.getIntegerType(64), b.getIndexAttr(8000000000));
73+
EXPECT_FALSE(value);
74+
75+
// Target width between 32 and 64 bits.
76+
fold(b.getIntegerType(40), b.getIndexAttr(0x10000000010000));
77+
// Fold succeeds because the upper bits are truncated in the cast.
78+
ASSERT_TRUE(value);
79+
EXPECT_EQ(value.getInt(), 65536);
80+
81+
// Fails to fold because the upper bits are not truncated.
82+
fold(b.getIntegerType(60), b.getIndexAttr(0x10000000010000));
83+
EXPECT_FALSE(value);
84+
}
85+
86+
TEST_F(IndexFolderTest, TestCastSOpFolder) {
87+
IntegerAttr value;
88+
auto fold = [&](Type type, Attribute input) {
89+
foldOp<index::CastSOp>(value, type, input);
90+
};
91+
92+
// Just test the extension cases to ensure signs are being respected.
93+
94+
// Target width greater than or equal to 64 bits.
95+
fold(b.getIntegerType(64), b.getIndexAttr(-2000));
96+
ASSERT_TRUE(value);
97+
EXPECT_EQ(value.getInt(), -2000);
98+
99+
// Target width between 32 and 64 bits.
100+
fold(b.getIntegerType(40), b.getIndexAttr(-0x10000000010000));
101+
// Fold succeeds because the upper bits are truncated in the cast.
102+
ASSERT_TRUE(value);
103+
EXPECT_EQ(value.getInt(), -65536);
104+
}

0 commit comments

Comments
 (0)