Skip to content

Skip test_moe_matmul_ogs on older cards #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2025
Merged

Skip test_moe_matmul_ogs on older cards #121

merged 1 commit into from
Jun 2, 2025

Conversation

jansel
Copy link
Contributor

@jansel jansel commented Jun 2, 2025

@jansel jansel force-pushed the jansel/stack/18 branch from 3115036 to 799834c Compare June 2, 2025 01:16
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 2, 2025
@jansel
Copy link
Contributor Author

jansel commented Jun 2, 2025

cc @yf225 @manman-ren

On my local machine, this test fails with the following assertion error. Possibly a Triton bug?

test/test_examples.py::TestExamples::test_moe_matmul_ogs python: /home/runner/work/triton/triton/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::detail::TypedValue<mlir::IntegerType>, From = mlir::Value]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
"builtin.module"() ({
  "tt.func"() <{arg_attrs = [{tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {tt.divisibility = 16 : i32}, {}, {}, {tt.divisibility = 16 : i32}, {}, {}, {}, {}], function_type = (!tt.ptr<i32>, !tt.ptr<i32>, !tt.ptr<i32>, !tt.ptr<f16>, !tt.ptr<f16>, !tt.ptr<f16>, i32, i32, i32, i32, i32, i32, i32) -> (), sym_name = "_moe_matmul_ogs_kernel", sym_visibility = "public"}> ({
  ^bb0(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>, %arg2: !tt.ptr<i32>, %arg3: !tt.ptr<f16>, %arg4: !tt.ptr<f16>, %arg5: !tt.ptr<f16>, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32):
    %3 = "tt.get_program_id"() <{axis = 0 : i32}> : () -> i32
    %4 = "tt.call"() <{callee = @"zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int32_"}> : () -> tensor<1xi32>
    %5 = "tt.splat"(%3) : (i32) -> tensor<1xi32>
    %6 = "arith.extsi"(%5) : (tensor<1xi32>) -> tensor<1xi64>
    %7 = "arith.extsi"(%4) : (tensor<1xi32>) -> tensor<1xi64>
    %8 = "arith.addi"(%6, %7) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
    %9 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
    %10 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
    %11 = "arith.constant"() <{value = dense<2147483647> : tensor<1xi64>}> : () -> tensor<1xi64>
    %12 = "arith.cmpi"(%8, %11) <{predicate = 3 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
    %13 = "arith.constant"() <{value = dense<-2147483648> : tensor<1xi64>}> : () -> tensor<1xi64>
    %14 = "arith.cmpi"(%8, %13) <{predicate = 5 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
    %15 = "arith.andi"(%12, %14) : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1>
    %16 = "arith.addi"(%5, %4) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
    %17 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    %18 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    %19 = "arith.constant"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
    %20 = "arith.extsi"(%16) : (tensor<1xi32>) -> tensor<1xi64>
    %21 = "arith.extsi"(%19) : (tensor<1xi32>) -> tensor<1xi64>
    %22 = "arith.muli"(%20, %21) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
    %23 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
    %24 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
    %25 = "arith.constant"() <{value = dense<2147483647> : tensor<1xi64>}> : () -> tensor<1xi64>
    %26 = "arith.cmpi"(%22, %25) <{predicate = 3 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
    %27 = "arith.constant"() <{value = dense<-2147483648> : tensor<1xi64>}> : () -> tensor<1xi64>
    %28 = "arith.cmpi"(%22, %27) <{predicate = 5 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
    %29 = "arith.andi"(%26, %28) : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1>
    %30 = "arith.muli"(%16, %19) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
    %31 = "tt.splat"(%arg0) : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>>
    %32 = "tt.addptr"(%31, %30) : (tensor<1x!tt.ptr<i32>>, tensor<1xi32>) -> tensor<1x!tt.ptr<i32>>
    %33 = "tt.load"(%32) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>}> : (tensor<1x!tt.ptr<i32>>) -> tensor<1xi32>
    %34 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    %35 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    %36 = "arith.constant"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
    %37 = "arith.extsi"(%16) : (tensor<1xi32>) -> tensor<1xi64>
    %38 = "arith.extsi"(%36) : (tensor<1xi32>) -> tensor<1xi64>
    %39 = "arith.muli"(%37, %38) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
    %40 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
    %41 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
    %42 = "arith.constant"() <{value = dense<2147483647> : tensor<1xi64>}> : () -> tensor<1xi64>
    %43 = "arith.cmpi"(%39, %42) <{predicate = 3 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
    %44 = "arith.constant"() <{value = dense<-2147483648> : tensor<1xi64>}> : () -> tensor<1xi64>
    %45 = "arith.cmpi"(%39, %44) <{predicate = 5 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
    %46 = "arith.andi"(%43, %45) : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1>
    %47 = "arith.muli"(%16, %36) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
    %48 = "tt.splat"(%arg1) : (!tt.ptr<i32>) -> tensor<1x!tt.ptr<i32>>
    %49 = "tt.addptr"(%48, %47) : (tensor<1x!tt.ptr<i32>>, tensor<1xi32>) -> tensor<1x!tt.ptr<i32>>
    %50 = "tt.load"(%49) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 0, 0>}> : (tensor<1x!tt.ptr<i32>>) -> tensor<1xi32>
    %51 = "arith.constant"() <{value = 0 : i32}> : () -> i32
    %52 = "arith.constant"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
    %53 = "arith.cmpi"(%50, %52) <{predicate = 1 : i64}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    "scf.if"(%53) ({
      %54 = "arith.constant"() <{value = 0 : i32}> : () -> i32
      %55 = "arith.constant"() <{value = 16 : i32}> : () -> i32
      %56 = "arith.bitcast"(%54) : (i32) -> i32
      %57 = "arith.bitcast"(%arg10) : (i32) -> i32
      %58 = "arith.bitcast"(%55) : (i32) -> i32
      %59 = "ub.poison"() <{value = #ub.poison}> : () -> i32
      "scf.for"(%56, %57, %58) ({
      ^bb0(%arg13: i32):
        %60 = "tt.make_range"() <{end = 16 : i32, start = 0 : i32}> : () -> tensor<16xi32>
        %61 = "tt.splat"(%arg13) : (i32) -> tensor<16xi32>
        %62 = "arith.extsi"(%61) : (tensor<16xi32>) -> tensor<16xi64>
        %63 = "arith.extsi"(%60) : (tensor<16xi32>) -> tensor<16xi64>
        %64 = "arith.addi"(%62, %63) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi64>
        %65 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
        %66 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
        %67 = "arith.constant"() <{value = dense<2147483647> : tensor<16xi64>}> : () -> tensor<16xi64>
        %68 = "arith.cmpi"(%64, %67) <{predicate = 3 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
        %69 = "arith.constant"() <{value = dense<-2147483648> : tensor<16xi64>}> : () -> tensor<16xi64>
        %70 = "arith.cmpi"(%64, %69) <{predicate = 5 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
        %71 = "arith.andi"(%68, %70) : (tensor<16xi1>, tensor<16xi1>) -> tensor<16xi1>
        %72 = "arith.addi"(%61, %60) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi32>
        %73 = "tt.splat"(%arg10) : (i32) -> tensor<16xi32>
        %74 = "arith.cmpi"(%72, %73) <{predicate = 2 : i64}> : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi1>
        %75 = "arith.constant"() <{value = 0 : i32}> : () -> i32
        %76 = "arith.constant"() <{value = 16 : i32}> : () -> i32
        %77 = "arith.bitcast"(%75) : (i32) -> i32
        %78 = "arith.bitcast"(%arg11) : (i32) -> i32
        %79 = "arith.bitcast"(%76) : (i32) -> i32
        %80 = "ub.poison"() <{value = #ub.poison}> : () -> i32
        "scf.for"(%77, %78, %79) ({
        ^bb0(%arg14: i32):
          %81 = "tt.make_range"() <{end = 16 : i32, start = 0 : i32}> : () -> tensor<16xi32>
          %82 = "tt.splat"(%arg14) : (i32) -> tensor<16xi32>
          %83 = "arith.extsi"(%82) : (tensor<16xi32>) -> tensor<16xi64>
          %84 = "arith.extsi"(%81) : (tensor<16xi32>) -> tensor<16xi64>
          %85 = "arith.addi"(%83, %84) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi64>
          %86 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %87 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %88 = "arith.constant"() <{value = dense<2147483647> : tensor<16xi64>}> : () -> tensor<16xi64>
          %89 = "arith.cmpi"(%85, %88) <{predicate = 3 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
          %90 = "arith.constant"() <{value = dense<-2147483648> : tensor<16xi64>}> : () -> tensor<16xi64>
          %91 = "arith.cmpi"(%85, %90) <{predicate = 5 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
          %92 = "arith.andi"(%89, %91) : (tensor<16xi1>, tensor<16xi1>) -> tensor<16xi1>
          %93 = "arith.addi"(%82, %81) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi32>
          %94 = "tt.splat"(%arg11) : (i32) -> tensor<16xi32>
          %95 = "arith.cmpi"(%93, %94) <{predicate = 2 : i64}> : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi1>
          %96 = "tt.expand_dims"(%50) <{axis = 0 : i32}> : (tensor<1xi32>) -> tensor<1x1xi32>
          %97 = "tt.expand_dims"(%72) <{axis = 0 : i32}> : (tensor<16xi32>) -> tensor<1x16xi32>
          %98 = "tt.broadcast"(%96) : (tensor<1x1xi32>) -> tensor<1x16xi32>
          %99 = "arith.cmpi"(%97, %98) <{predicate = 2 : i64}> : (tensor<1x16xi32>, tensor<1x16xi32>) -> tensor<1x16xi1>
          %100 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %101 = "arith.constant"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
          %102 = "arith.constant"() <{value = dense<0> : tensor<16xi32>}> : () -> tensor<16xi32>
          %103 = "tt.expand_dims"(%72) <{axis = 0 : i32}> : (tensor<16xi32>) -> tensor<1x16xi32>
          %104 = "tt.expand_dims"(%102) <{axis = 0 : i32}> : (tensor<16xi32>) -> tensor<1x16xi32>
          %105 = "arith.select"(%99, %103, %104) : (tensor<1x16xi1>, tensor<1x16xi32>, tensor<1x16xi32>) -> tensor<1x16xi32>
          %106 = "tt.expand_dims"(%33) <{axis = 0 : i32}> : (tensor<1xi32>) -> tensor<1x1xi32>
          %107 = "tt.broadcast"(%106) : (tensor<1x1xi32>) -> tensor<1x16xi32>
          %108 = "arith.extsi"(%107) : (tensor<1x16xi32>) -> tensor<1x16xi64>
          %109 = "arith.extsi"(%105) : (tensor<1x16xi32>) -> tensor<1x16xi64>
          %110 = "arith.addi"(%108, %109) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi64>
          %111 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %112 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %113 = "arith.constant"() <{value = dense<2147483647> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
          %114 = "arith.cmpi"(%110, %113) <{predicate = 3 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
          %115 = "arith.constant"() <{value = dense<-2147483648> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
          %116 = "arith.cmpi"(%110, %115) <{predicate = 5 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
          %117 = "arith.andi"(%114, %116) : (tensor<1x16xi1>, tensor<1x16xi1>) -> tensor<1x16xi1>
          %118 = "arith.addi"(%107, %105) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi32>, tensor<1x16xi32>) -> tensor<1x16xi32>
          %119 = "tt.reshape"(%118) : (tensor<1x16xi32>) -> tensor<16xi32>
          %120 = "arith.constant"() <{value = 1 : i32}> : () -> i32
          %121 = "arith.constant"() <{value = 1 : i32}> : () -> i32
          %122 = "arith.constant"() <{value = dense<1> : tensor<16xi32>}> : () -> tensor<16xi32>
          %123 = "arith.extsi"(%119) : (tensor<16xi32>) -> tensor<16xi64>
          %124 = "arith.extsi"(%122) : (tensor<16xi32>) -> tensor<16xi64>
          %125 = "arith.muli"(%123, %124) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi64>
          %126 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %127 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %128 = "arith.constant"() <{value = dense<2147483647> : tensor<16xi64>}> : () -> tensor<16xi64>
          %129 = "arith.cmpi"(%125, %128) <{predicate = 3 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
          %130 = "arith.constant"() <{value = dense<-2147483648> : tensor<16xi64>}> : () -> tensor<16xi64>
          %131 = "arith.cmpi"(%125, %130) <{predicate = 5 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
          %132 = "arith.andi"(%129, %131) : (tensor<16xi1>, tensor<16xi1>) -> tensor<16xi1>
          %133 = "arith.muli"(%119, %122) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi32>
          %134 = "tt.splat"(%arg2) : (!tt.ptr<i32>) -> tensor<16x!tt.ptr<i32>>
          %135 = "tt.addptr"(%134, %133) : (tensor<16x!tt.ptr<i32>>, tensor<16xi32>) -> tensor<16x!tt.ptr<i32>>
          %136 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %137 = "arith.constant"() <{value = dense<0> : tensor<16xi32>}> : () -> tensor<16xi32>
          %138 = "tt.load"(%135, %74, %137) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<16x!tt.ptr<i32>>, tensor<16xi1>, tensor<16xi32>) -> tensor<16xi32>
          %139 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
          %140 = "arith.constant"() <{value = dense<0.000000e+00> : tensor<16x16xf32>}> : () -> tensor<16x16xf32>
          %141 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %142 = "arith.constant"() <{value = 16 : i32}> : () -> i32
          %143 = "arith.bitcast"(%141) : (i32) -> i32
          %144 = "arith.bitcast"(%arg12) : (i32) -> i32
          %145 = "arith.bitcast"(%142) : (i32) -> i32
          %146 = "ub.poison"() <{value = #ub.poison}> : () -> i32
          %147 = "scf.for"(%143, %144, %145, %140) ({
          ^bb0(%arg15: i32, %arg16: tensor<16x16xf32>):
            %252 = "tt.make_range"() <{end = 16 : i32, start = 0 : i32}> : () -> tensor<16xi32>
            %253 = "tt.splat"(%arg15) : (i32) -> tensor<16xi32>
            %254 = "arith.extsi"(%253) : (tensor<16xi32>) -> tensor<16xi64>
            %255 = "arith.extsi"(%252) : (tensor<16xi32>) -> tensor<16xi64>
            %256 = "arith.addi"(%254, %255) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi64>
            %257 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %258 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %259 = "arith.constant"() <{value = dense<2147483647> : tensor<16xi64>}> : () -> tensor<16xi64>
            %260 = "arith.cmpi"(%256, %259) <{predicate = 3 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
            %261 = "arith.constant"() <{value = dense<-2147483648> : tensor<16xi64>}> : () -> tensor<16xi64>
            %262 = "arith.cmpi"(%256, %261) <{predicate = 5 : i64}> : (tensor<16xi64>, tensor<16xi64>) -> tensor<16xi1>
            %263 = "arith.andi"(%260, %262) : (tensor<16xi1>, tensor<16xi1>) -> tensor<16xi1>
            %264 = "arith.addi"(%253, %252) <{overflowFlags = #arith.overflow<none>}> : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi32>
            %265 = "tt.splat"(%arg12) : (i32) -> tensor<16xi32>
            %266 = "arith.cmpi"(%264, %265) <{predicate = 2 : i64}> : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi1>
            %267 = "tt.expand_dims"(%138) <{axis = 1 : i32}> : (tensor<16xi32>) -> tensor<16x1xi32>
            %268 = "tt.splat"(%arg6) : (i32) -> tensor<16x1xi32>
            %269 = "arith.extsi"(%267) : (tensor<16x1xi32>) -> tensor<16x1xi64>
            %270 = "arith.extsi"(%268) : (tensor<16x1xi32>) -> tensor<16x1xi64>
            %271 = "arith.muli"(%269, %270) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi64>
            %272 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %273 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %274 = "arith.constant"() <{value = dense<2147483647> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
            %275 = "arith.cmpi"(%271, %274) <{predicate = 3 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
            %276 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
            %277 = "arith.cmpi"(%271, %276) <{predicate = 5 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
            %278 = "arith.andi"(%275, %277) : (tensor<16x1xi1>, tensor<16x1xi1>) -> tensor<16x1xi1>
            %279 = "arith.muli"(%267, %268) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi32>, tensor<16x1xi32>) -> tensor<16x1xi32>
            %280 = "tt.expand_dims"(%264) <{axis = 0 : i32}> : (tensor<16xi32>) -> tensor<1x16xi32>
            %281 = "arith.constant"() <{value = 1 : i32}> : () -> i32
            %282 = "arith.constant"() <{value = 1 : i32}> : () -> i32
            %283 = "arith.constant"() <{value = dense<1> : tensor<1x16xi32>}> : () -> tensor<1x16xi32>
            %284 = "arith.extsi"(%280) : (tensor<1x16xi32>) -> tensor<1x16xi64>
            %285 = "arith.extsi"(%283) : (tensor<1x16xi32>) -> tensor<1x16xi64>
            %286 = "arith.muli"(%284, %285) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi64>
            %287 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %288 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %289 = "arith.constant"() <{value = dense<2147483647> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
            %290 = "arith.cmpi"(%286, %289) <{predicate = 3 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
            %291 = "arith.constant"() <{value = dense<-2147483648> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
            %292 = "arith.cmpi"(%286, %291) <{predicate = 5 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
            %293 = "arith.andi"(%290, %292) : (tensor<1x16xi1>, tensor<1x16xi1>) -> tensor<1x16xi1>
            %294 = "arith.muli"(%280, %283) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi32>, tensor<1x16xi32>) -> tensor<1x16xi32>
            %295 = "tt.broadcast"(%279) : (tensor<16x1xi32>) -> tensor<16x16xi32>
            %296 = "tt.broadcast"(%294) : (tensor<1x16xi32>) -> tensor<16x16xi32>
            %297 = "arith.extsi"(%295) : (tensor<16x16xi32>) -> tensor<16x16xi64>
            %298 = "arith.extsi"(%296) : (tensor<16x16xi32>) -> tensor<16x16xi64>
            %299 = "arith.addi"(%297, %298) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi64>
            %300 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %301 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %302 = "arith.constant"() <{value = dense<2147483647> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
            %303 = "arith.cmpi"(%299, %302) <{predicate = 3 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
            %304 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
            %305 = "arith.cmpi"(%299, %304) <{predicate = 5 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
            %306 = "arith.andi"(%303, %305) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
            %307 = "arith.addi"(%295, %296) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi32>, tensor<16x16xi32>) -> tensor<16x16xi32>
            %308 = "tt.splat"(%arg3) : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>>
            %309 = "tt.addptr"(%308, %307) : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi32>) -> tensor<16x16x!tt.ptr<f16>>
            %310 = "tt.expand_dims"(%74) <{axis = 1 : i32}> : (tensor<16xi1>) -> tensor<16x1xi1>
            %311 = "tt.expand_dims"(%266) <{axis = 0 : i32}> : (tensor<16xi1>) -> tensor<1x16xi1>
            %312 = "tt.broadcast"(%310) : (tensor<16x1xi1>) -> tensor<16x16xi1>
            %313 = "tt.broadcast"(%311) : (tensor<1x16xi1>) -> tensor<16x16xi1>
            %314 = "arith.andi"(%312, %313) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
            %315 = "arith.constant"() <{value = 0 : i32}> : () -> i32
            %316 = "arith.constant"() <{value = dense<0> : tensor<16x16xi32>}> : () -> tensor<16x16xi32>
            %317 = "arith.sitofp"(%316) : (tensor<16x16xi32>) -> tensor<16x16xf16>
            %318 = "tt.load"(%309, %314, %317) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi1>, tensor<16x16xf16>) -> tensor<16x16xf16>
            %319 = "tt.splat"(%arg8) : (i32) -> tensor<1xi32>
            %320 = "arith.extsi"(%16) : (tensor<1xi32>) -> tensor<1xi64>
            %321 = "arith.extsi"(%319) : (tensor<1xi32>) -> tensor<1xi64>
            %322 = "arith.muli"(%320, %321) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
            %323 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %324 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %325 = "arith.constant"() <{value = dense<2147483647> : tensor<1xi64>}> : () -> tensor<1xi64>
            %326 = "arith.cmpi"(%322, %325) <{predicate = 3 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
            %327 = "arith.constant"() <{value = dense<-2147483648> : tensor<1xi64>}> : () -> tensor<1xi64>
            %328 = "arith.cmpi"(%322, %327) <{predicate = 5 : i64}> : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi1>
            %329 = "arith.andi"(%326, %328) : (tensor<1xi1>, tensor<1xi1>) -> tensor<1xi1>
            %330 = "arith.muli"(%16, %319) <{overflowFlags = #arith.overflow<none>}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
            %331 = "tt.expand_dims"(%264) <{axis = 1 : i32}> : (tensor<16xi32>) -> tensor<16x1xi32>
            %332 = "tt.splat"(%arg9) : (i32) -> tensor<16x1xi32>
            %333 = "arith.extsi"(%331) : (tensor<16x1xi32>) -> tensor<16x1xi64>
            %334 = "arith.extsi"(%332) : (tensor<16x1xi32>) -> tensor<16x1xi64>
            %335 = "arith.muli"(%333, %334) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi64>
            %336 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %337 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %338 = "arith.constant"() <{value = dense<2147483647> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
            %339 = "arith.cmpi"(%335, %338) <{predicate = 3 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
            %340 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
            %341 = "arith.cmpi"(%335, %340) <{predicate = 5 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
            %342 = "arith.andi"(%339, %341) : (tensor<16x1xi1>, tensor<16x1xi1>) -> tensor<16x1xi1>
            %343 = "arith.muli"(%331, %332) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi32>, tensor<16x1xi32>) -> tensor<16x1xi32>
            %344 = "tt.expand_dims"(%330) <{axis = 0 : i32}> : (tensor<1xi32>) -> tensor<1x1xi32>
            %345 = "tt.broadcast"(%344) : (tensor<1x1xi32>) -> tensor<16x1xi32>
            %346 = "arith.extsi"(%345) : (tensor<16x1xi32>) -> tensor<16x1xi64>
            %347 = "arith.extsi"(%343) : (tensor<16x1xi32>) -> tensor<16x1xi64>
            %348 = "arith.addi"(%346, %347) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi64>
            %349 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %350 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %351 = "arith.constant"() <{value = dense<2147483647> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
            %352 = "arith.cmpi"(%348, %351) <{predicate = 3 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
            %353 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
            %354 = "arith.cmpi"(%348, %353) <{predicate = 5 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
            %355 = "arith.andi"(%352, %354) : (tensor<16x1xi1>, tensor<16x1xi1>) -> tensor<16x1xi1>
            %356 = "arith.addi"(%345, %343) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi32>, tensor<16x1xi32>) -> tensor<16x1xi32>
            %357 = "tt.expand_dims"(%93) <{axis = 0 : i32}> : (tensor<16xi32>) -> tensor<1x16xi32>
            %358 = "arith.constant"() <{value = 1 : i32}> : () -> i32
            %359 = "arith.constant"() <{value = 1 : i32}> : () -> i32
            %360 = "arith.constant"() <{value = dense<1> : tensor<1x16xi32>}> : () -> tensor<1x16xi32>
            %361 = "arith.extsi"(%357) : (tensor<1x16xi32>) -> tensor<1x16xi64>
            %362 = "arith.extsi"(%360) : (tensor<1x16xi32>) -> tensor<1x16xi64>
            %363 = "arith.muli"(%361, %362) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi64>
            %364 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %365 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %366 = "arith.constant"() <{value = dense<2147483647> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
            %367 = "arith.cmpi"(%363, %366) <{predicate = 3 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
            %368 = "arith.constant"() <{value = dense<-2147483648> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
            %369 = "arith.cmpi"(%363, %368) <{predicate = 5 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
            %370 = "arith.andi"(%367, %369) : (tensor<1x16xi1>, tensor<1x16xi1>) -> tensor<1x16xi1>
            %371 = "arith.muli"(%357, %360) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi32>, tensor<1x16xi32>) -> tensor<1x16xi32>
            %372 = "tt.broadcast"(%356) : (tensor<16x1xi32>) -> tensor<16x16xi32>
            %373 = "tt.broadcast"(%371) : (tensor<1x16xi32>) -> tensor<16x16xi32>
            %374 = "arith.extsi"(%372) : (tensor<16x16xi32>) -> tensor<16x16xi64>
            %375 = "arith.extsi"(%373) : (tensor<16x16xi32>) -> tensor<16x16xi64>
            %376 = "arith.addi"(%374, %375) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi64>
            %377 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
            %378 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
            %379 = "arith.constant"() <{value = dense<2147483647> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
            %380 = "arith.cmpi"(%376, %379) <{predicate = 3 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
            %381 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
            %382 = "arith.cmpi"(%376, %381) <{predicate = 5 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
            %383 = "arith.andi"(%380, %382) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
            %384 = "arith.addi"(%372, %373) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi32>, tensor<16x16xi32>) -> tensor<16x16xi32>
            %385 = "tt.splat"(%arg4) : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>>
            %386 = "tt.addptr"(%385, %384) : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi32>) -> tensor<16x16x!tt.ptr<f16>>
            %387 = "tt.expand_dims"(%266) <{axis = 1 : i32}> : (tensor<16xi1>) -> tensor<16x1xi1>
            %388 = "tt.expand_dims"(%95) <{axis = 0 : i32}> : (tensor<16xi1>) -> tensor<1x16xi1>
            %389 = "tt.broadcast"(%387) : (tensor<16x1xi1>) -> tensor<16x16xi1>
            %390 = "tt.broadcast"(%388) : (tensor<1x16xi1>) -> tensor<16x16xi1>
            %391 = "arith.andi"(%389, %390) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
            %392 = "arith.constant"() <{value = 0 : i32}> : () -> i32
            %393 = "arith.constant"() <{value = dense<0> : tensor<16x16xi32>}> : () -> tensor<16x16xi32>
            %394 = "arith.sitofp"(%393) : (tensor<16x16xi32>) -> tensor<16x16xf16>
            %395 = "tt.load"(%386, %391, %394) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi1>, tensor<16x16xf16>) -> tensor<16x16xf16>
            %396 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
            %397 = "tt.dot"(%318, %395, %arg16) <{inputPrecision = 0 : i32, maxNumImpreciseAcc = 0 : i32}> : (tensor<16x16xf16>, tensor<16x16xf16>, tensor<16x16xf32>) -> tensor<16x16xf32>
            "scf.yield"(%397) : (tensor<16x16xf32>) -> ()
          }) : (i32, i32, i32, tensor<16x16xf32>) -> tensor<16x16xf32>
          %148 = "tt.expand_dims"(%138) <{axis = 1 : i32}> : (tensor<16xi32>) -> tensor<16x1xi32>
          %149 = "tt.splat"(%arg7) : (i32) -> tensor<16x1xi32>
          %150 = "arith.extsi"(%148) : (tensor<16x1xi32>) -> tensor<16x1xi64>
          %151 = "arith.extsi"(%149) : (tensor<16x1xi32>) -> tensor<16x1xi64>
          %152 = "arith.muli"(%150, %151) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi64>
          %153 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %154 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %155 = "arith.constant"() <{value = dense<2147483647> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
          %156 = "arith.cmpi"(%152, %155) <{predicate = 3 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
          %157 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
          %158 = "arith.cmpi"(%152, %157) <{predicate = 5 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
          %159 = "arith.andi"(%156, %158) : (tensor<16x1xi1>, tensor<16x1xi1>) -> tensor<16x1xi1>
          %160 = "arith.muli"(%148, %149) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi32>, tensor<16x1xi32>) -> tensor<16x1xi32>
          %161 = "tt.expand_dims"(%93) <{axis = 0 : i32}> : (tensor<16xi32>) -> tensor<1x16xi32>
          %162 = "arith.constant"() <{value = 1 : i32}> : () -> i32
          %163 = "arith.constant"() <{value = 1 : i32}> : () -> i32
          %164 = "arith.constant"() <{value = dense<1> : tensor<1x16xi32>}> : () -> tensor<1x16xi32>
          %165 = "arith.extsi"(%161) : (tensor<1x16xi32>) -> tensor<1x16xi64>
          %166 = "arith.extsi"(%164) : (tensor<1x16xi32>) -> tensor<1x16xi64>
          %167 = "arith.muli"(%165, %166) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi64>
          %168 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %169 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %170 = "arith.constant"() <{value = dense<2147483647> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
          %171 = "arith.cmpi"(%167, %170) <{predicate = 3 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
          %172 = "arith.constant"() <{value = dense<-2147483648> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
          %173 = "arith.cmpi"(%167, %172) <{predicate = 5 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
          %174 = "arith.andi"(%171, %173) : (tensor<1x16xi1>, tensor<1x16xi1>) -> tensor<1x16xi1>
          %175 = "arith.muli"(%161, %164) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi32>, tensor<1x16xi32>) -> tensor<1x16xi32>
          %176 = "tt.broadcast"(%160) : (tensor<16x1xi32>) -> tensor<16x16xi32>
          %177 = "tt.broadcast"(%175) : (tensor<1x16xi32>) -> tensor<16x16xi32>
          %178 = "arith.extsi"(%176) : (tensor<16x16xi32>) -> tensor<16x16xi64>
          %179 = "arith.extsi"(%177) : (tensor<16x16xi32>) -> tensor<16x16xi64>
          %180 = "arith.addi"(%178, %179) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi64>
          %181 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %182 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %183 = "arith.constant"() <{value = dense<2147483647> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
          %184 = "arith.cmpi"(%180, %183) <{predicate = 3 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
          %185 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
          %186 = "arith.cmpi"(%180, %185) <{predicate = 5 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
          %187 = "arith.andi"(%184, %186) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
          %188 = "arith.addi"(%176, %177) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi32>, tensor<16x16xi32>) -> tensor<16x16xi32>
          %189 = "tt.splat"(%arg5) : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>>
          %190 = "tt.addptr"(%189, %188) : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi32>) -> tensor<16x16x!tt.ptr<f16>>
          %191 = "tt.expand_dims"(%74) <{axis = 1 : i32}> : (tensor<16xi1>) -> tensor<16x1xi1>
          %192 = "tt.expand_dims"(%95) <{axis = 0 : i32}> : (tensor<16xi1>) -> tensor<1x16xi1>
          %193 = "tt.broadcast"(%191) : (tensor<16x1xi1>) -> tensor<16x16xi1>
          %194 = "tt.broadcast"(%192) : (tensor<1x16xi1>) -> tensor<16x16xi1>
          %195 = "arith.andi"(%193, %194) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
          %196 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %197 = "arith.constant"() <{value = dense<0> : tensor<16x16xi32>}> : () -> tensor<16x16xi32>
          %198 = "arith.sitofp"(%197) : (tensor<16x16xi32>) -> tensor<16x16xf16>
          %199 = "tt.load"(%190, %195, %198) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operandSegmentSizes = array<i32: 1, 1, 1>}> : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi1>, tensor<16x16xf16>) -> tensor<16x16xf16>
          %200 = "tt.reshape"(%99) : (tensor<1x16xi1>) -> tensor<16x1xi1>
          %201 = "tt.broadcast"(%200) : (tensor<16x1xi1>) -> tensor<16x16xi1>
          %202 = "arith.truncf"(%147) : (tensor<16x16xf32>) -> tensor<16x16xf16>
          %203 = "arith.select"(%201, %202, %199) : (tensor<16x16xi1>, tensor<16x16xf16>, tensor<16x16xf16>) -> tensor<16x16xf16>
          %204 = "tt.expand_dims"(%138) <{axis = 1 : i32}> : (tensor<16xi32>) -> tensor<16x1xi32>
          %205 = "tt.splat"(%arg7) : (i32) -> tensor<16x1xi32>
          %206 = "arith.extsi"(%204) : (tensor<16x1xi32>) -> tensor<16x1xi64>
          %207 = "arith.extsi"(%205) : (tensor<16x1xi32>) -> tensor<16x1xi64>
          %208 = "arith.muli"(%206, %207) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi64>
          %209 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %210 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %211 = "arith.constant"() <{value = dense<2147483647> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
          %212 = "arith.cmpi"(%208, %211) <{predicate = 3 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
          %213 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x1xi64>}> : () -> tensor<16x1xi64>
          %214 = "arith.cmpi"(%208, %213) <{predicate = 5 : i64}> : (tensor<16x1xi64>, tensor<16x1xi64>) -> tensor<16x1xi1>
          %215 = "arith.andi"(%212, %214) : (tensor<16x1xi1>, tensor<16x1xi1>) -> tensor<16x1xi1>
          %216 = "arith.muli"(%204, %205) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x1xi32>, tensor<16x1xi32>) -> tensor<16x1xi32>
          %217 = "tt.expand_dims"(%93) <{axis = 0 : i32}> : (tensor<16xi32>) -> tensor<1x16xi32>
          %218 = "arith.constant"() <{value = 1 : i32}> : () -> i32
          %219 = "arith.constant"() <{value = 1 : i32}> : () -> i32
          %220 = "arith.constant"() <{value = dense<1> : tensor<1x16xi32>}> : () -> tensor<1x16xi32>
          %221 = "arith.extsi"(%217) : (tensor<1x16xi32>) -> tensor<1x16xi64>
          %222 = "arith.extsi"(%220) : (tensor<1x16xi32>) -> tensor<1x16xi64>
          %223 = "arith.muli"(%221, %222) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi64>
          %224 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %225 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %226 = "arith.constant"() <{value = dense<2147483647> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
          %227 = "arith.cmpi"(%223, %226) <{predicate = 3 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
          %228 = "arith.constant"() <{value = dense<-2147483648> : tensor<1x16xi64>}> : () -> tensor<1x16xi64>
          %229 = "arith.cmpi"(%223, %228) <{predicate = 5 : i64}> : (tensor<1x16xi64>, tensor<1x16xi64>) -> tensor<1x16xi1>
          %230 = "arith.andi"(%227, %229) : (tensor<1x16xi1>, tensor<1x16xi1>) -> tensor<1x16xi1>
          %231 = "arith.muli"(%217, %220) <{overflowFlags = #arith.overflow<none>}> : (tensor<1x16xi32>, tensor<1x16xi32>) -> tensor<1x16xi32>
          %232 = "tt.broadcast"(%216) : (tensor<16x1xi32>) -> tensor<16x16xi32>
          %233 = "tt.broadcast"(%231) : (tensor<1x16xi32>) -> tensor<16x16xi32>
          %234 = "arith.extsi"(%232) : (tensor<16x16xi32>) -> tensor<16x16xi64>
          %235 = "arith.extsi"(%233) : (tensor<16x16xi32>) -> tensor<16x16xi64>
          %236 = "arith.addi"(%234, %235) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi64>
          %237 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
          %238 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
          %239 = "arith.constant"() <{value = dense<2147483647> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
          %240 = "arith.cmpi"(%236, %239) <{predicate = 3 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
          %241 = "arith.constant"() <{value = dense<-2147483648> : tensor<16x16xi64>}> : () -> tensor<16x16xi64>
          %242 = "arith.cmpi"(%236, %241) <{predicate = 5 : i64}> : (tensor<16x16xi64>, tensor<16x16xi64>) -> tensor<16x16xi1>
          %243 = "arith.andi"(%240, %242) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
          %244 = "arith.addi"(%232, %233) <{overflowFlags = #arith.overflow<none>}> : (tensor<16x16xi32>, tensor<16x16xi32>) -> tensor<16x16xi32>
          %245 = "tt.splat"(%arg5) : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>>
          %246 = "tt.addptr"(%245, %244) : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi32>) -> tensor<16x16x!tt.ptr<f16>>
          %247 = "tt.expand_dims"(%74) <{axis = 1 : i32}> : (tensor<16xi1>) -> tensor<16x1xi1>
          %248 = "tt.expand_dims"(%95) <{axis = 0 : i32}> : (tensor<16xi1>) -> tensor<1x16xi1>
          %249 = "tt.broadcast"(%247) : (tensor<16x1xi1>) -> tensor<16x16xi1>
          %250 = "tt.broadcast"(%248) : (tensor<1x16xi1>) -> tensor<16x16xi1>
          %251 = "arith.andi"(%249, %250) : (tensor<16x16xi1>, tensor<16x16xi1>) -> tensor<16x16xi1>
          "tt.store"(%246, %203, %251) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32}> : (tensor<16x16x!tt.ptr<f16>>, tensor<16x16xf16>, tensor<16x16xi1>) -> ()
          "scf.yield"() : () -> ()
        }) : (i32, i32, i32) -> ()
        "scf.yield"() : () -> ()
      }) : (i32, i32, i32) -> ()
      "scf.yield"() : () -> ()
    }, {
      "scf.yield"() : () -> ()
    }) : (tensor<1xi1>) -> ()
    "tt.return"() : () -> ()
  }) {noinline = false} : () -> ()
  // Private helper function: takes no arguments and returns a tensor<1xi32>
  // whose single element is 0.
  // NOTE(review): the mangled sym_name suggests this is a specialised instance of a
  // `zeros` helper for shape (1,) and dtype int32 - confirm against the frontend.
  "tt.func"() <{function_type = () -> tensor<1xi32>, sym_name = "zeros____(0, 0)cconstexpr_1__(1,)cconstexpr_int32_", sym_visibility = "private"}> ({
    // Scalar zero; not referenced by any other op in this function.
    %0 = "arith.constant"() <{value = 0 : i32}> : () -> i32
    // Dense all-zero tensor that is the function's result.
    %1 = "arith.constant"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
    "tt.return"(%1) : (tensor<1xi32>) -> ()
  // Block with no predecessors (unreachable from the entry block); it returns a
  // poison value of the result type.
  ^bb1:  // no predecessors
    %2 = "ub.poison"() <{value = #ub.poison}> : () -> tensor<1xi32>
    "tt.return"(%2) : (tensor<1xi32>) -> ()
  }) {noinline = false} : () -> ()
}) : () -> ()

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(inline{default-pipeline=canonicalize inlining-threshold=4294967295 max-iterations=4 }, triton-rewrite-tensor-pointer, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, triton-combine, triton-reorder-broadcast, cse, symbol-dce, triton-loop-unroll)",
      disable_threading: false,
      verify_each: true
    }
  }
#-}
/tmp/torchinductor_jansel/qy/cqy6fygklghcmief2vg7i3kbi7xecl7hnagpfyk4a25mxnrizjw4.py:8:0: error: A signal was caught while processing the MLIR module:reproducer generated at `std::errs, please share the reproducer above with Triton project.`; marking pass as failed
zsh: IOT instruction (core dumped)  pytest test -vs

@@ -1450,7 +1454,6 @@ def test_moe_matmul_ogs(self):
helion_kernel_args,
mod.moe_matmul_ogs_reference(*args),
block_sizes=[[16, 16], 16],
l2_grouping=4,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This config is a no-op for this kernel because the kernel uses hl.grid(). (I removed it in the next PR in the stack.)

stack-info: PR: #121, branch: jansel/stack/18
@jansel jansel force-pushed the jansel/stack/18 branch from 799834c to 53506b5 Compare June 2, 2025 03:23
@jansel jansel merged commit 6838f04 into main Jun 2, 2025
4 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants