
[MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N] #142797


Open
nbpatel wants to merge 13 commits into main from xegpu_wg_sg_elementwise

Conversation

@nbpatel (Contributor) commented Jun 4, 2025

This PR adds support for lowering elementwise operations (unary and binary) from workgroup to subgroup.
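For illustration, here is a minimal before/after sketch of this lowering, with shapes borrowed from the tests added in this PR (sg_layout = [2, 4] with sg_data = [12, 8] splits a 24x32 workgroup vector into per-subgroup 12x8 tiles):

// Workgroup-level elementwise op (before):
%r = arith.addf %a, %b
  {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8],
                          lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<24x32xf32>

// Per-subgroup op (after --xegpu-wg-to-sg-distribute): sg_layout and
// sg_data are dropped, and the remaining layout is attached as
// layout_result_0.
%r_sg = arith.addf %a_sg, %b_sg
  {layout_result_0 = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}
  : vector<12x8xf32>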

@nbpatel force-pushed the xegpu_wg_sg_elementwise branch from 0b58685 to e215e22 on June 6, 2025 at 15:50
@nbpatel changed the title from "[MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass" to "[MLIR][XeGPU] Add support for elementwise ops in Wg to Sg distribute pass [1/N]" on June 6, 2025
@nbpatel marked this pull request as ready for review on June 6, 2025 at 17:24
@llvmbot (Member) commented Jun 6, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds support for lowering elementwise operations (unary and binary) from workgroup to subgroup.


Patch is 87.28 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/142797.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+163)
  • (added) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir (+1048)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..e1687031d259a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,15 +8,18 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include <optional>
 
 namespace mlir {
 namespace xegpu {
@@ -314,6 +317,90 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+// This pattern transforms elementwise ops (unary/binary) in math/arith dialect
+template <typename Op>
+struct WgToSgElementwiseOp : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+  using OneToNOpAdaptor = typename OpConversionPattern<Op>::OneToNOpAdaptor;
+
+  LogicalResult
+  matchAndRewrite(Op op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands/results must be 1D or 2D vectors
+    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType || (resultType.getRank() != 1 && resultType.getRank() != 2))
+      return rewriter.notifyMatchFailure(
+          op, "Result type is not a 1D or 2D vector");
+
+    ArrayRef<int64_t> shape = resultType.getShape();
+    for (Value operand : op->getOperands()) {
+      auto operandType = dyn_cast<VectorType>(operand.getType());
+      if (!operandType || operandType.getRank() != resultType.getRank() ||
+          operandType.getShape() != shape) {
+        return rewriter.notifyMatchFailure(
+            op, "Operand type is not a 1D or 2D vector with the same shape as "
+                "result type");
+      }
+    }
+
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!layout || !layout.getSgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "Operation does not have a valid layout attribute for subgroup "
+              "distribution");
+
+    // Extract sgShape from layout
+    SmallVector<int64_t> sgShape;
+    if (auto sgDataAttr = layout.getSgData()) {
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    } else {
+      auto sgLayoutArr = layout.getSgLayout();
+      sgShape.reserve(shape.size());
+      for (size_t i = 0; i < shape.size(); ++i) {
+        assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+        sgShape.push_back(shape[i] / sgLayoutArr[i]);
+      }
+    }
+
+    size_t numVariants = adaptor.getOperands().empty()
+                             ? 0
+                             : adaptor.getOperands().front().size();
+    for (auto &operandVec : adaptor.getOperands())
+      if (operandVec.size() != numVariants)
+        return rewriter.notifyMatchFailure(
+            op, "Operand lists have mismatched sizes");
+
+    SmallVector<Value> newResults;
+
+    auto origResultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    VectorType newResultType =
+        origResultType
+            ? VectorType::get(sgShape, origResultType.getElementType())
+            : VectorType::get(sgShape, resultType.getElementType());
+
+    for (size_t i = 0; i < numVariants; ++i) {
+      SmallVector<Value> operands;
+      for (auto &operandVec : adaptor.getOperands())
+        operands.push_back(operandVec[i]);
+
+      auto newOp = rewriter.create<Op>(op.getLoc(), newResultType, operands);
+
+      // Copy all attributes except "layout", and add "layout_result_0" with
+      // sgLayout/data dropped
+      for (auto attr : op->getAttrs()) {
+        if (attr.getName() != "layout")
+          newOp->setAttr(attr.getName(), attr.getValue());
+      }
+      newOp->setAttr("layout_result_0", layout.dropSgLayoutAndData());
+
+      newResults.push_back(newOp.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newResults});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -322,6 +409,57 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
       patterns.getContext());
+  // Add elementwise operations that can be distributed to subgroups
+  patterns.add<
+      WgToSgElementwiseOp<arith::AddFOp>, WgToSgElementwiseOp<arith::SubFOp>,
+      WgToSgElementwiseOp<math::ExpOp>, WgToSgElementwiseOp<math::SqrtOp>,
+      WgToSgElementwiseOp<math::AbsFOp>, WgToSgElementwiseOp<math::CosOp>,
+      WgToSgElementwiseOp<math::CoshOp>, WgToSgElementwiseOp<math::AcosOp>,
+      WgToSgElementwiseOp<math::AcoshOp>, WgToSgElementwiseOp<math::SinOp>,
+      WgToSgElementwiseOp<math::SinhOp>, WgToSgElementwiseOp<math::AsinOp>,
+      WgToSgElementwiseOp<math::AsinhOp>, WgToSgElementwiseOp<math::TanOp>,
+      WgToSgElementwiseOp<math::TanhOp>, WgToSgElementwiseOp<math::AtanOp>,
+      WgToSgElementwiseOp<math::Atan2Op>, WgToSgElementwiseOp<math::AtanhOp>,
+      WgToSgElementwiseOp<math::ErfOp>, WgToSgElementwiseOp<math::LogOp>,
+      WgToSgElementwiseOp<math::Log2Op>, WgToSgElementwiseOp<math::FloorOp>,
+      WgToSgElementwiseOp<math::CeilOp>, WgToSgElementwiseOp<math::PowFOp>,
+      WgToSgElementwiseOp<math::RsqrtOp>, WgToSgElementwiseOp<arith::NegFOp>,
+      WgToSgElementwiseOp<arith::AddIOp>, WgToSgElementwiseOp<arith::SubIOp>,
+      WgToSgElementwiseOp<arith::MulFOp>, WgToSgElementwiseOp<arith::MulIOp>,
+      WgToSgElementwiseOp<arith::ShLIOp>, WgToSgElementwiseOp<arith::ShRSIOp>,
+      WgToSgElementwiseOp<arith::ShRUIOp>, WgToSgElementwiseOp<arith::DivFOp>,
+      WgToSgElementwiseOp<arith::DivSIOp>, WgToSgElementwiseOp<arith::DivUIOp>,
+      WgToSgElementwiseOp<arith::MaximumFOp>,
+      WgToSgElementwiseOp<arith::MinimumFOp>,
+      WgToSgElementwiseOp<arith::RemSIOp>, WgToSgElementwiseOp<arith::RemUIOp>,
+      WgToSgElementwiseOp<arith::TruncFOp>,
+      WgToSgElementwiseOp<arith::TruncIOp>, WgToSgElementwiseOp<arith::ExtFOp>,
+      WgToSgElementwiseOp<arith::ExtSIOp>, WgToSgElementwiseOp<arith::ExtUIOp>,
+      WgToSgElementwiseOp<arith::SIToFPOp>,
+      WgToSgElementwiseOp<arith::UIToFPOp>,
+      WgToSgElementwiseOp<arith::FPToSIOp>,
+      WgToSgElementwiseOp<arith::FPToUIOp>,
+      WgToSgElementwiseOp<arith::IndexCastUIOp>,
+      WgToSgElementwiseOp<arith::IndexCastOp>,
+      WgToSgElementwiseOp<arith::BitcastOp>, WgToSgElementwiseOp<arith::CmpIOp>,
+      WgToSgElementwiseOp<arith::CmpFOp>, WgToSgElementwiseOp<arith::AndIOp>,
+      WgToSgElementwiseOp<arith::CeilDivSIOp>,
+      WgToSgElementwiseOp<arith::CeilDivUIOp>,
+      WgToSgElementwiseOp<arith::FloorDivSIOp>,
+      WgToSgElementwiseOp<arith::MaxNumFOp>,
+      WgToSgElementwiseOp<arith::MaxSIOp>, WgToSgElementwiseOp<arith::MaxUIOp>,
+      WgToSgElementwiseOp<arith::MinNumFOp>,
+      WgToSgElementwiseOp<arith::MinSIOp>, WgToSgElementwiseOp<arith::MinUIOp>,
+      WgToSgElementwiseOp<arith::OrIOp>, WgToSgElementwiseOp<arith::RemFOp>,
+      WgToSgElementwiseOp<arith::SelectOp>, WgToSgElementwiseOp<arith::XOrIOp>,
+      WgToSgElementwiseOp<math::AbsIOp>, WgToSgElementwiseOp<math::CbrtOp>,
+      WgToSgElementwiseOp<math::CopySignOp>, WgToSgElementwiseOp<math::CtPopOp>,
+      WgToSgElementwiseOp<math::ErfcOp>, WgToSgElementwiseOp<math::Exp2Op>,
+      WgToSgElementwiseOp<math::ExpM1Op>, WgToSgElementwiseOp<math::FPowIOp>,
+      WgToSgElementwiseOp<math::IPowIOp>, WgToSgElementwiseOp<math::Log10Op>,
+      WgToSgElementwiseOp<math::Log1pOp>, WgToSgElementwiseOp<math::RoundOp>,
+      WgToSgElementwiseOp<math::RoundEvenOp>,
+      WgToSgElementwiseOp<math::TruncOp>>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -368,6 +506,31 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
     return isLegal(layout);
   });
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [=](Operation *op) -> std::optional<bool> {
+        // Handle unary and binary operations
+        if (op->getNumOperands() < 1 || op->getNumOperands() > 2)
+          return true;
+
+        // check if input and output are vectors
+        VectorType resultType =
+            dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultType || resultType.getRank() != 2)
+          return true;
+
+        // Check if all operands are vectors
+        for (Value operand : op->getOperands()) {
+          VectorType operandType = dyn_cast<VectorType>(operand.getType());
+          if (!operandType || operandType.getRank() != 2 ||
+              operandType.getShape() != resultType.getShape()) {
+            return true;
+          }
+        }
+
+        auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
+            op->getAttrOfType<xegpu::LayoutAttr>("layout"));
+        return isLegal(layout);
+      });
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
new file mode 100644
index 0000000000000..85767f4f2bd67
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -0,0 +1,1048 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @elementwise_ops {
+   // CHECK-LABEL: elemwise_ops
+   gpu.func @elemwise_ops(%a: memref<24x32xf32>, %b: memref<24x32xf32>, %c: memref<24x32xi32>, %d: memref<24x32xi32>) {
+        %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+          -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+        %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<24x32xf32>
+          -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+        %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x32xi32>
+          -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+        %tdesc_d = xegpu.create_nd_tdesc %d[0, 0] : memref<24x32xi32>
+          -> !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+
+        %load_a = xegpu.load_nd %tdesc_a
+          : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xf32>
+        %load_b = xegpu.load_nd %tdesc_b
+          : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xf32>
+        %load_c = xegpu.load_nd %tdesc_c
+          : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xi32>
+        %load_d = xegpu.load_nd %tdesc_d
+          : !xegpu.tensor_desc<24x32xi32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+          -> vector<24x32xi32>
+
+        // Floating point ops
+        // CHECK: arith.addf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: arith.subf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.exp {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.sqrt {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.absf {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.cos {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.cosh {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.acos {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.acosh {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.sin {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.sinh {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.asin {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.asinh {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.tan {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.tanh {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.atan {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.atan2 {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.atanh {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.erf {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.log {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.log2 {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.floor {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.ceil {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.powf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: math.rsqrt {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: arith.negf {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: arith.mulf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: arith.divf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: arith.maximumf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: arith.minimumf {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xf32>
+        // CHECK: arith.addi {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.subi {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.muli {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.shli {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.shrsi {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.shrui {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.divsi {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.divui {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.remsi {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        // CHECK: arith.remui {{.*}}, {{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>} : vector<12x8xi32>
+        %addf = arith.addf %load_a, %load_b
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %subf = arith.subf %load_a, %load_b
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %exp = math.exp %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %sqrt = math.sqrt %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %absf = math.absf %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %cos = math.cos %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %cosh = math.cosh %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %acos = math.acos %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %acosh = math.acosh %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %sin = math.sin %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %sinh = math.sinh %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %asin = math.asin %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %asinh = math.asinh %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %tan = math.tan %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %tanh = math.tanh %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %atan = math.atan %load_a
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %atan2 = math.atan2 %load_a, %load_b
+          {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>}
+          : vector<24x32xf32>
+        %atanh = math.a...
[truncated]

@chencha3 (Contributor) left a comment:
Generally, looks good to me with some comments.

// Only match ops with elementwise trait
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(op, "Not an elementwise op");

Contributor: It would be worth having a check that the number of results is 1.
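A minimal sketch of such a guard (hypothetical wording, not from the patch), matching the style of the pattern's other checks:

// Bail out unless the op produces exactly one result; the pattern only
// rewrites single-result elementwise ops.
if (op->getNumResults() != 1)
  return rewriter.notifyMatchFailure(op, "expected a single-result op");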

state.addAttribute(attr.getName(), attr.getValue());
}
Operation *newOp = rewriter.create(state);
xegpu::setLayoutAttr(newOp->getResult(0), layout.dropSgLayoutAndData());
Contributor: nit: you can move this into the loop so the code reads more compactly (all attributes are handled in one place).
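One possible reading of this nit, as a hypothetical sketch (state and layout as in the quoted snippet above; not necessarily what the patch ends up doing):

for (mlir::NamedAttribute attr : op->getAttrs()) {
  if (attr.getName() == "layout")
    // Instead of copying the workgroup-level layout verbatim, attach the
    // per-subgroup result layout in the same loop.
    state.addAttribute("layout_result_0", layout.dropSgLayoutAndData());
  else
    state.addAttribute(attr.getName(), attr.getValue());
}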

nbpatel (Author): The attribute is added to the newOp (the transformed op), and the loop traverses the attributes, not numVariants.

return true;

// Check if all operands are vectors of the same shape
for (Value operand : op->getOperands()) {
Contributor: nit: consider using llvm::all_equal on op->getOperandTypes().
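For illustration, a sketch of how that suggestion could look (a hypothetical helper, not part of the patch; llvm::all_equal comes from llvm/ADT/STLExtras.h):

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"

// Returns true if every operand has the same vector type and that type
// matches the result's shape; replaces the per-operand comparison loop.
static bool hasUniformVectorOperands(mlir::Operation *op,
                                     mlir::VectorType resultType) {
  if (op->getNumOperands() == 0 || !llvm::all_equal(op->getOperandTypes()))
    return false;
  auto operandType =
      llvm::dyn_cast<mlir::VectorType>(op->getOperand(0).getType());
  return operandType && operandType.getShape() == resultType.getShape();
}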

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
Contributor: Are these two headers still needed after using the trait?

nbpatel (Author): Yes, for the legality check.
