
Commit 50ab6f6

add the code for heaviside op
1 parent 38d5f99 commit 50ab6f6

File tree

8 files changed: +187 −0 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 47 additions & 0 deletions
@@ -13077,6 +13077,53 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
   let hasFolder = 1;
 }
 
+def Torch_AtenHeavisideOp : Torch_Op<"aten.heaviside", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::heaviside : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeavisideOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeavisideOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenHeaviside_Op : Torch_Op<"aten.heaviside_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::heaviside_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeaviside_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeaviside_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
     AllowsTypeRefinement,
     HasValueSemantics,
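For reference, the out-of-place op and its trailing-underscore in-place variant correspond to the usual PyTorch-level calls. A minimal eager-mode sketch (tensor values chosen only for illustration):

    import torch

    x = torch.tensor([1.5, 0.0, -2.0])
    values = torch.tensor([0.5])

    out = torch.heaviside(x, values)  # maps to aten::heaviside -> tensor([1.0, 0.5, 0.0])
    x.heaviside_(values)              # in-place variant, maps to aten::heaviside_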

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 12 additions & 0 deletions
@@ -9671,6 +9671,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.heaviside\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+" return %0 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
@@ -15192,6 +15196,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 " return %4 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.heaviside\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+" return %4 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 71 additions & 0 deletions
@@ -11068,6 +11068,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
 };
 } // namespace
 
+namespace {
+// Decomposes the aten.heaviside op using aten.eq, aten.lt, aten.logical_or,
+// and aten.where.
+// Heaviside(x, y) returns:
+//   0 if x < 0
+//   y if x == 0
+//   1 if x > 0
+class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHeavisideOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.getSelf();
+    auto value = op.getValues();
+    auto loc = op.getLoc();
+    auto inputTy = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "input must have dtype and size.");
+
+    auto valueTy = dyn_cast<BaseTensorType>(value.getType());
+    if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
+    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
+                          broadcastShapeValue);
+
+    auto broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
+    auto boolBroadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
+    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+    auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, input, indexBroadcastShapeTorchList);
+    auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, value, indexBroadcastShapeTorchList);
+
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
+                                                   resultTy.getDtype());
+    Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
+                                                  resultTy.getDtype());
+    // Compute mask: input == 0
+    auto inputEqZero = rewriter
+                           .create<AtenEqScalarOp>(loc, boolBroadcastType,
+                                                   inputBroadcasted, zero)
+                           ->getResult(0);
+    // Compute mask: input < 0
+    auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
+                                                       inputBroadcasted, zero);
+    // Compute mask: isnan(input)
+    auto isNan =
+        rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
+    // Combine: input < 0 || isnan(input)
+    auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
+        loc, boolBroadcastType, inputLtZero, isNan);
+    // Select 0 if input < 0 or input is nan, else 1.
+    auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
+        loc, resultTy, inputNegativeOrNan, zero, one);
+    // Final result: if input == 0, take from valueBroadcasted, else take from
+    // zerosOrOnes.
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
+                                                 valueBroadcasted, zerosOrOnes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Unconditionally decompose `torch.type_as` into `prim.dtype` +
 // `torch.to.dtype`.
@@ -12291,6 +12361,7 @@ class DecomposeComplexOpsPass
       DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
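Read as plain PyTorch, the decomposition above is a broadcast followed by two selects. A minimal Python mirror of that structure (the helper name is hypothetical and for illustration only; as in the pattern, NaN inputs deliberately fall into the same branch as negative inputs via the logical_or of lt and isnan):

    import torch

    def heaviside_decomposed(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        # Broadcast both operands to the common shape (aten.broadcast_to twice).
        x_b, v_b = torch.broadcast_tensors(x, values)
        zero = torch.zeros((), dtype=x_b.dtype)
        one = torch.ones((), dtype=x_b.dtype)
        eq_zero = x_b == 0                              # aten.eq.Scalar
        lt_zero_or_nan = (x_b < 0) | torch.isnan(x_b)   # aten.lt.Scalar, aten.isnan, aten.logical_or
        zeros_or_ones = torch.where(lt_zero_or_nan, zero, one)  # aten.where.Scalar
        return torch.where(eq_zero, v_b, zeros_or_ones)         # aten.where.self

    x = torch.tensor([1.0, -2.0, 0.0])
    print(heaviside_decomposed(x, torch.tensor([5.0])))  # tensor([1., 0., 5.])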

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -460,6 +460,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenSquareOp>();
   target.addIllegalOp<AtenVarOp>();
   target.addIllegalOp<AtenStdOp>();
+  target.addIllegalOp<AtenHeavisideOp>();
   target.addIllegalOp<Aten_UnsafeViewOp>();
   target.addIllegalOp<Aten_ReshapeAliasOp>();
   target.addIllegalOp<AtenBernoulliOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -1249,6 +1249,7 @@
     "ElementwiseToDtypeI64ToI8Module_basic",
     "ElementwiseToDtypeIdentityModule_basic",
     "ElementwiseUnaryModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "EmptyLikeMemoryFormatModule_basic",
     "EmptyLikeModule_defaultDtype",
     "EmptyLikeModule_falsePinMemory",
@@ -1849,6 +1850,7 @@
     "ElementwiseFracModule_basic",
     "ElementwiseLdexpModule_basic",
     "ElementwiseSignbitIntModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "Exp2StaticIntModule_basic",
     "MaxPool1dEmptyStrideStaticModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
@@ -2958,6 +2960,8 @@
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
    "HardtanhBackward_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",
     "HstackBasicIntFloatModule_basic",
@@ -3958,6 +3962,8 @@
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",
     "ElementwiseRreluWithNoiseTrainModule_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "RreluWithNoiseBackwardEvalModule_basic",
     "RreluWithNoiseBackwardEvalStaticModule_basic",
     "RreluWithNoiseBackwardTrainModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 0 deletions
@@ -1772,6 +1772,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
 def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(condition, other)
 
+def aten〇heaviside〡shape(self: List[int], values: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, values)
+
 def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -5069,6 +5072,14 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
     dtypes = [get_dtype_of_scalar(self), other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+def aten〇heaviside〡dtype(self_rank_dtype: Tuple[int, int], values_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    values_rank, values_dtype = values_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, values_rank]
+    dtypes = [self_dtype, values_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:
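For intuition: the new shape function is plain broadcasting of the two operand shapes, and the new dtype function is standard promotion over the two operand dtypes. A rough stand-in using stock PyTorch utilities (illustration only, not the library-generator machinery itself; shapes and dtypes are made up):

    import torch

    # Shape: broadcast(self, values), e.g. [5, 1] with [3] -> [5, 3]
    print(torch.broadcast_shapes((5, 1), (3,)))  # torch.Size([5, 3])

    # Dtype: promote_dtypes over (self, values), e.g. float32 with float64 -> float64
    print(torch.result_type(torch.empty(5, 1, dtype=torch.float32),
                            torch.empty(3, dtype=torch.float64)))  # torch.float64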

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -958,6 +958,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
     )
+    emit_with_mutating_variants("aten::heaviside : (Tensor, Tensor) -> (Tensor)")
     emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
     emit(
         "aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py

Lines changed: 38 additions & 0 deletions
@@ -298,6 +298,44 @@ def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseHeavisideModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5], torch.float32, True), ([1], torch.float32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideModule())
+def ElementwiseHeavisideModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf]), torch.tensor([5.0])
+    )
+
+
+class ElementwiseHeavisideIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int32, True), ([-1], torch.int32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
+def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(5, 1, low=-100, high=1000).to(torch.int32),
+        tu.randint(1, low=-100, high=1000).to(torch.int32),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseLtIntScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
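As a quick sanity check on the float test inputs above, the documented rule (0 below zero, `values` at zero, 1 above zero) gives, for example:

    import torch

    x = torch.tensor([1.0, -2.0, 0.0, float("inf"), float("-inf")])
    print(torch.heaviside(x, torch.tensor([5.0])))  # tensor([1., 0., 5., 1., 0.])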
