Skip to content

Commit aabfaf9

Browse files
committed
[mlir] Allow empty lists for DenseArrayAttr.
Differential Revision: https://reviews.llvm.org/D129552
1 parent 3c5d631 commit aabfaf9

File tree

5 files changed

+71
-22
lines changed

5 files changed

+71
-22
lines changed

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1875,24 +1875,26 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
18751875
typeElision = AttrTypeElision::Must;
18761876
switch (denseArrayAttr.getElementType()) {
18771877
case DenseArrayBaseAttr::EltType::I8:
1878-
os << "[:i8 ";
1878+
os << "[:i8";
18791879
break;
18801880
case DenseArrayBaseAttr::EltType::I16:
1881-
os << "[:i16 ";
1881+
os << "[:i16";
18821882
break;
18831883
case DenseArrayBaseAttr::EltType::I32:
1884-
os << "[:i32 ";
1884+
os << "[:i32";
18851885
break;
18861886
case DenseArrayBaseAttr::EltType::I64:
1887-
os << "[:i64 ";
1887+
os << "[:i64";
18881888
break;
18891889
case DenseArrayBaseAttr::EltType::F32:
1890-
os << "[:f32 ";
1890+
os << "[:f32";
18911891
break;
18921892
case DenseArrayBaseAttr::EltType::F64:
1893-
os << "[:f64 ";
1893+
os << "[:f64";
18941894
break;
18951895
}
1896+
if (denseArrayAttr.getType().cast<ShapedType>().getRank())
1897+
os << " ";
18961898
denseArrayAttr.printWithoutBraces(os);
18971899
os << "]";
18981900
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,9 @@ template <typename T>
838838
Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
839839
if (parser.parseLSquare())
840840
return {};
841+
// Handle empty list case.
842+
if (succeeded(parser.parseOptionalRSquare()))
843+
return get(parser.getContext(), {});
841844
Attribute result = parseWithoutBraces(parser, odsType);
842845
if (parser.parseRSquare())
843846
return {};
@@ -860,42 +863,48 @@ struct denseArrayAttrEltTypeBuilder;
860863
template <>
861864
struct denseArrayAttrEltTypeBuilder<int8_t> {
862865
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
863-
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
866+
static ShapedType getShapedType(MLIRContext *context,
867+
ArrayRef<int64_t> shape) {
864868
return VectorType::get(shape, IntegerType::get(context, 8));
865869
}
866870
};
867871
template <>
868872
struct denseArrayAttrEltTypeBuilder<int16_t> {
869873
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
870-
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
874+
static ShapedType getShapedType(MLIRContext *context,
875+
ArrayRef<int64_t> shape) {
871876
return VectorType::get(shape, IntegerType::get(context, 16));
872877
}
873878
};
874879
template <>
875880
struct denseArrayAttrEltTypeBuilder<int32_t> {
876881
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
877-
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
882+
static ShapedType getShapedType(MLIRContext *context,
883+
ArrayRef<int64_t> shape) {
878884
return VectorType::get(shape, IntegerType::get(context, 32));
879885
}
880886
};
881887
template <>
882888
struct denseArrayAttrEltTypeBuilder<int64_t> {
883889
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
884-
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
890+
static ShapedType getShapedType(MLIRContext *context,
891+
ArrayRef<int64_t> shape) {
885892
return VectorType::get(shape, IntegerType::get(context, 64));
886893
}
887894
};
888895
template <>
889896
struct denseArrayAttrEltTypeBuilder<float> {
890897
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
891-
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
898+
static ShapedType getShapedType(MLIRContext *context,
899+
ArrayRef<int64_t> shape) {
892900
return VectorType::get(shape, Float32Type::get(context));
893901
}
894902
};
895903
template <>
896904
struct denseArrayAttrEltTypeBuilder<double> {
897905
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
898-
static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
906+
static ShapedType getShapedType(MLIRContext *context,
907+
ArrayRef<int64_t> shape) {
899908
return VectorType::get(shape, Float64Type::get(context));
900909
}
901910
};
@@ -905,8 +914,9 @@ struct denseArrayAttrEltTypeBuilder<double> {
905914
template <typename T>
906915
DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
907916
ArrayRef<T> content) {
908-
auto shapedType =
909-
denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
917+
auto size = static_cast<int64_t>(content.size());
918+
auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType(
919+
context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{});
910920
auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
911921
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
912922
content.size() * sizeof(T));

mlir/lib/Parser/AttributeParser.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -844,19 +844,34 @@ Attribute Parser::parseDenseArrayAttr() {
844844
return {};
845845
CustomAsmParser parser(*this);
846846
Attribute result;
847+
// Check for empty list.
848+
bool isEmptyList = getToken().is(Token::r_square);
849+
847850
if (auto intType = type.dyn_cast<IntegerType>()) {
848851
switch (type.getIntOrFloatBitWidth()) {
849852
case 8:
850-
result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
853+
if (isEmptyList)
854+
result = DenseI8ArrayAttr::get(parser.getContext(), {});
855+
else
856+
result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
851857
break;
852858
case 16:
853-
result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
859+
if (isEmptyList)
860+
result = DenseI16ArrayAttr::get(parser.getContext(), {});
861+
else
862+
result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
854863
break;
855864
case 32:
856-
result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
865+
if (isEmptyList)
866+
result = DenseI32ArrayAttr::get(parser.getContext(), {});
867+
else
868+
result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
857869
break;
858870
case 64:
859-
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
871+
if (isEmptyList)
872+
result = DenseI64ArrayAttr::get(parser.getContext(), {});
873+
else
874+
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
860875
break;
861876
default:
862877
emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
@@ -865,10 +880,16 @@ Attribute Parser::parseDenseArrayAttr() {
865880
} else if (auto floatType = type.dyn_cast<FloatType>()) {
866881
switch (type.getIntOrFloatBitWidth()) {
867882
case 32:
868-
result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
883+
if (isEmptyList)
884+
result = DenseF32ArrayAttr::get(parser.getContext(), {});
885+
else
886+
result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
869887
break;
870888
case 64:
871-
result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
889+
if (isEmptyList)
890+
result = DenseF64ArrayAttr::get(parser.getContext(), {});
891+
else
892+
result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
872893
break;
873894
default:
874895
emitError(typeLoc, "expected f32 or f64 but got: ") << type;

mlir/test/IR/attribute.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,19 @@ func.func @simple_scalar_example() {
521521
//===----------------------------------------------------------------------===//
522522

523523
// CHECK-LABEL: func @dense_array_attr
524-
func.func @dense_array_attr() attributes{
524+
func.func @dense_array_attr() attributes{
525+
// CHECK-SAME: emptyf32attr = [:f32],
526+
emptyf32attr = [:f32],
527+
// CHECK-SAME: emptyf64attr = [:f64],
528+
emptyf64attr = [:f64],
529+
// CHECK-SAME: emptyi16attr = [:i16],
530+
emptyi16attr = [:i16],
531+
// CHECK-SAME: emptyi32attr = [:i32],
532+
emptyi32attr = [:i32],
533+
// CHECK-SAME: emptyi64attr = [:i64],
534+
emptyi64attr = [:i64],
535+
// CHECK-SAME: emptyi8attr = [:i8],
536+
emptyi8attr = [:i8],
525537
// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
526538
f32attr = [:f32 1024., 453., -6435.],
527539
// CHECK-SAME: f64attr = [:f64 -1.420000e+02],
@@ -549,6 +561,8 @@ func.func @dense_array_attr() attributes{
549561
f32attr = [1024., 453., -6435.]
550562
// CHECK-SAME: f64attr = [-1.420000e+02]
551563
f64attr = [-142.]
564+
// CHECK-SAME: emptyattr = []
565+
emptyattr = []
552566
return
553567
}
554568

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,13 @@ def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
277277
DenseI32ArrayAttr:$i32attr,
278278
DenseI64ArrayAttr:$i64attr,
279279
DenseF32ArrayAttr:$f32attr,
280-
DenseF64ArrayAttr:$f64attr
280+
DenseF64ArrayAttr:$f64attr,
281+
DenseI32ArrayAttr:$emptyattr
281282
);
282283
let assemblyFormat = [{
283284
`i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
284285
`i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
286+
`emptyattr` `=` $emptyattr
285287
attr-dict
286288
}];
287289
}

0 commit comments

Comments
 (0)