[MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] #142797
Conversation
Force-pushed from 0b58685 to e215e22.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes: This PR adds support for lowering elementwise operations (unary and binary) from the workgroup level to the subgroup level.

Patch is 87.28 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/142797.diff

2 Files Affected:
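For context, the pass rewrites an elementwise op carrying a workgroup-level `layout` attribute into the equivalent op on the per-subgroup shape (`sg_data`, or shape divided by `sg_layout`), dropping the subgroup fields from the result layout. A minimal before/after sketch, using the shapes from the test file in this patch:

```mlir
// Before: workgroup-level op over the full vector<24x32xf32>.
%r = arith.addf %a, %b
  {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8],
                          lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<24x32xf32>

// After: each subgroup computes its own 12x8 tile; sg_layout/sg_data are
// dropped and the remaining layout is attached as layout_result_0.
%r_sg = arith.addf %a_sg, %b_sg
  {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<12x8xf32>
```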
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..e1687031d259a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,15 +8,18 @@
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <optional>
namespace mlir {
namespace xegpu {
@@ -314,6 +317,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+// This pattern transforms elementwise ops (unary/binary) in math/arith dialect
+template <typename Op>
+struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
+ using OpConversionPattern<Op>::OpConversionPattern;
+ using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // All operands/results must be 1D or 2D vectors
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
+ return rewriter.notifyMatchFailure(
+ op, "Result type is not a 1D or 2D vector");
+
+ ArrayRef<int64_t> shape = resultType.getShape();
+ for (Value operand : op->getOperands()) {
+ auto operandType = dyn_cast<VectorType>(operand.getType());
+ if (!operandType || operandType.getRank() != resultType.getRank() ||
+ operandType.getShape() != shape) {
+ return rewriter.notifyMatchFailure(
+ op, "Operand type is not a 1D or 2D vector with the same shape as "
+ "result type");
+ }
+ }
+
+ auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+ if (!layout || !layout.getSgLayout())
+ return rewriter.notifyMatchFailure(
+ op, "Operation does not have a valid layout attribute for subgroup "
+ "distribution");
+
+ // Extract sgShape from layout
+ SmallVector<int64_t> sgShape;
+ if (auto sgDataAttr = layout.getSgData()) {
+ sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+ } else {
+ auto sgLayoutArr = layout.getSgLayout();
+ sgShape.reserve(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i) {
+ assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+ sgShape.push_back(shape[i] / sgLayoutArr[i]);
+ }
+ }
+
+ size_t numVariants = adaptor.getOperands().empty()
+ ? 0
+ : adaptor.getOperands().front().size();
+ for (auto &operandVec : adaptor.getOperands())
+ if (operandVec.size() != numVariants)
+ return rewriter.notifyMatchFailure(
+ op, "Operand lists have mismatched sizes");
+
+ SmallVector<Value> newResults;
+
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+ for (size_t i = 0; i < numVariants; ++i) {
+ SmallVector<Value> operands;
+ for (auto &operandVec : adaptor.getOperands())
+ operands.push_back(operandVec[i]);
+
+ auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
+
+ // Copy all attributes except "layout", and add "layout_result_0" with
+ // sgLayout/data dropped
+ for (auto attr : op->getAttrs()) {
+ if (attr.getName() != "layout")
+ newOp->setAttr(attr.getName(), attr.getValue());
+ }
+ newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
+
+ newResults.push_back(newOp.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newResults});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
@@ -322,6 +409,57 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
patterns.getContext());
+ // Add elementwise operations that can be distributed to subgroups
+ patterns.add<
+ WgToSgElementwiseOp<arith::AddFOp>, WgToSgElementwiseOp<arith::SubFOp>,
+ WgToSgElementwiseOp<math::ExpOp>, WgToSgElementwiseOp<math::SqrtOp>,
+ WgToSgElementwiseOp<math::AbsFOp>, WgToSgElementwiseOp<math::CosOp>,
+ WgToSgElementwiseOp<math::CoshOp>, WgToSgElementwiseOp<math::AcosOp>,
+ WgToSgElementwiseOp<math::AcoshOp>, WgToSgElementwiseOp<math::SinOp>,
+ WgToSgElementwiseOp<math::SinhOp>, WgToSgElementwiseOp<math::AsinOp>,
+ WgToSgElementwiseOp<math::AsinhOp>, WgToSgElementwiseOp<math::TanOp>,
+ WgToSgElementwiseOp<math::TanhOp>, WgToSgElementwiseOp<math::AtanOp>,
+ WgToSgElementwiseOp<math::Atan2Op>, WgToSgElementwiseOp<math::AtanhOp>,
+ WgToSgElementwiseOp<math::ErfOp>, WgToSgElementwiseOp<math::LogOp>,
+ WgToSgElementwiseOp<math::Log2Op>, WgToSgElementwiseOp<math::FloorOp>,
+ WgToSgElementwiseOp<math::CeilOp>, WgToSgElementwiseOp<math::PowFOp>,
+ WgToSgElementwiseOp<math::RsqrtOp>, WgToSgElementwiseOp<arith::NegFOp>,
+ WgToSgElementwiseOp<arith::AddIOp>, WgToSgElementwiseOp<arith::SubIOp>,
+ WgToSgElementwiseOp<arith::MulFOp>, WgToSgElementwiseOp<arith::MulIOp>,
+ WgToSgElementwiseOp<arith::ShLIOp>, WgToSgElementwiseOp<arith::ShRSIOp>,
+ WgToSgElementwiseOp<arith::ShRUIOp>, WgToSgElementwiseOp<arith::DivFOp>,
+ WgToSgElementwiseOp<arith::DivSIOp>, WgToSgElementwiseOp<arith::DivUIOp>,
+ WgToSgElementwiseOp<arith::MaximumFOp>,
+ WgToSgElementwiseOp<arith::MinimumFOp>,
+ WgToSgElementwiseOp<arith::RemSIOp>, WgToSgElementwiseOp<arith::RemUIOp>,
+ WgToSgElementwiseOp<arith::TruncFOp>,
+ WgToSgElementwiseOp<arith::TruncIOp>, WgToSgElementwiseOp<arith::ExtFOp>,
+ WgToSgElementwiseOp<arith::ExtSIOp>, WgToSgElementwiseOp<arith::ExtUIOp>,
+ WgToSgElementwiseOp<arith::SIToFPOp>,
+ WgToSgElementwiseOp<arith::UIToFPOp>,
+ WgToSgElementwiseOp<arith::FPToSIOp>,
+ WgToSgElementwiseOp<arith::FPToUIOp>,
+ WgToSgElementwiseOp<arith::IndexCastUIOp>,
+ WgToSgElementwiseOp<arith::IndexCastOp>,
+ WgToSgElementwiseOp<arith::BitcastOp>, WgToSgElementwiseOp<arith::CmpIOp>,
+ WgToSgElementwiseOp<arith::CmpFOp>, WgToSgElementwiseOp<arith::AndIOp>,
+ WgToSgElementwiseOp<arith::CeilDivSIOp>,
+ WgToSgElementwiseOp<arith::CeilDivUIOp>,
+ WgToSgElementwiseOp<arith::FloorDivSIOp>,
+ WgToSgElementwiseOp<arith::MaxNumFOp>,
+ WgToSgElementwiseOp<arith::MaxSIOp>, WgToSgElementwiseOp<arith::MaxUIOp>,
+ WgToSgElementwiseOp<arith::MinNumFOp>,
+ WgToSgElementwiseOp<arith::MinSIOp>, WgToSgElementwiseOp<arith::MinUIOp>,
+ WgToSgElementwiseOp<arith::OrIOp>, WgToSgElementwiseOp<arith::RemFOp>,
+ WgToSgElementwiseOp<arith::SelectOp>, WgToSgElementwiseOp<arith::XOrIOp>,
+ WgToSgElementwiseOp<math::AbsIOp>, WgToSgElementwiseOp<math::CbrtOp>,
+ WgToSgElementwiseOp<math::CopySignOp>, WgToSgElementwiseOp<math::CtPopOp>,
+ WgToSgElementwiseOp<math::ErfcOp>, WgToSgElementwiseOp<math::Exp2Op>,
+ WgToSgElementwiseOp<math::ExpM1Op>, WgToSgElementwiseOp<math::FPowIOp>,
+ WgToSgElementwiseOp<math::IPowIOp>, WgToSgElementwiseOp<math::Log10Op>,
+ WgToSgElementwiseOp<math::Log1pOp>, WgToSgElementwiseOp<math::RoundOp>,
+ WgToSgElementwiseOp<math::RoundEvenOp>,
+ WgToSgElementwiseOp<math::TruncOp>>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -368,6 +506,31 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
return isLegal(layout);
});
+ target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+ [=](Operation *op) -> std::optional<bool> {
+ // Handle unary and binary operations
+ if (op->getNumOperands() < 1 || op->getNumOperands() > 2)
+ return true;
+
+ // check if input and output are vectors
+ VectorType resultType =
+ dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType || resultType.getRank() != 2)
+ return true;
+
+ // Check if all operands are vectors
+ for (Value operand : op->getOperands()) {
+ VectorType operandType = dyn_cast<VectorType>(operand.getType());
+ if (!operandType || operandType.getRank() != 2 ||
+ operandType.getShape() != resultType.getShape()) {
+ return true;
+ }
+ }
+
+          auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
+ return isLegal(layout);
+ });
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..85767f4f2bd67
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,1048 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @elementwise_ops {
+ // CHECK-LABEL: elemwise_ops
+ gpu.func @elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+ %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+ -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+ -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+ %load_a = xegpu.load_nd %tdesc_a
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_b = xegpu.load_nd %tdesc_b
+ : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xf32>
+ %load_c = xegpu.load_nd %tdesc_c
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+ %load_d = xegpu.load_nd %tdesc_d
+ : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+ -> vector<24x32xi32>
+
+ // Floating point ops
+ // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.subf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.exp {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.absf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.cos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.cosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.acos {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.acosh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.sinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.asin {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.asinh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.tan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.tanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atan {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atan2 {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.atanh {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.erf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.log {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.log2 {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.floor {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.ceil {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: math.rsqrt {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.mulf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.divf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.maximumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.minimumf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+ // CHECK: arith.addi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.subi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.muli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shli {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shrsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.shrui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.divsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.divui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.remsi {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ // CHECK: arith.remui {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+ %addf = arith.addf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %subf = arith.subf %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %exp = math.exp %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sqrt = math.sqrt %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %absf = math.absf %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cos = math.cos %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %cosh = math.cosh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acos = math.acos %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %acosh = math.acosh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sin = math.sin %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %sinh = math.sinh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asin = math.asin %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %asinh = math.asinh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tan = math.tan %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %tanh = math.tanh %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan = math.atan %load_a
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atan2 = math.atan2 %load_a, %load_b
+ {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+ : vector<24x32xf32>
+ %atanh = math.a...
[truncated]
Generally looks good to me, with some comments.
// Only match ops with elementwise trait
if (!OpTrait::hasElementwiseMappableTraits(op))
  return rewriter.notifyMatchFailure(op, "Not an elementwise op");
Worth having a check that the number of results is 1.
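For instance, a minimal sketch of the suggested guard (hypothetical placement at the top of `matchAndRewrite`):

```cpp
// Hypothetical guard: the pattern only supports single-result ops.
if (op->getNumResults() != 1)
  return rewriter.notifyMatchFailure(op, "Expected a single-result op");
```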
  state.addAttribute(attr.getName(), attr.getValue());
}
Operation *newOp = rewriter.create(state);
xegpu::setLayoutAttr(newOp->getResult(0), layout.dropSgLayoutAndData());
nit: you can move this into the loop so the code reads more compactly (all attributes are handled in one place).
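One reading of the suggestion, sketched against the quoted code above (hypothetical, not the committed version): rewrite the layout attribute inside the same loop instead of patching the result afterwards.

```cpp
// Hypothetical restructuring: handle "layout" together with the other
// attributes so all attribute rewriting happens in one place.
for (NamedAttribute attr : op->getAttrs()) {
  if (attr.getName() == "layout")
    state.addAttribute("layout_result_0", layout.dropSgLayoutAndData());
  else
    state.addAttribute(attr.getName(), attr.getValue());
}
Operation *newOp = rewriter.create(state);
```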
The attribute is added to the newOp (the transformed op), and the loop traverses the attributes, not numVariants.
  return true;

// Check if all operands are vectors of the same shape
for (Value operand : op->getOperands()) {
nit: consider using llvm::all_equal on op->getOperandTypes().
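A sketch of what that could look like inside the legality lambda (hypothetical; note it is stricter than the original shape-only check, since identical types also imply identical element types, which matters for cast ops):

```cpp
// Hypothetical: all operands must have exactly the same type.
if (!llvm::all_equal(op->getOperandTypes()))
  return true; // mixed operand types: leave the op legal (not distributed)
```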
#include "mlir/Dialect/Arith/Utils/Utils.h" | ||
#include "mlir/Dialect/GPU/IR/GPUDialect.h" | ||
#include "mlir/Dialect/Index/IR/IndexDialect.h" | ||
#include "mlir/Dialect/Index/IR/IndexOps.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" |
Are these two headers still needed after using the trait?
Yes, they are still needed for the legality check.
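For reference, the legality hook in the pass names both dialect types directly, which is what keeps the includes alive:

```cpp
// From the pass above: the dynamic-legality registration references the
// Math and Arith dialect types, so their headers must stay included.
target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
    [=](Operation *op) -> std::optional<bool> { /* ... */ });
```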