Commit f3f8301

add the code for heaviside op
1 parent 716303a commit f3f8301

8 files changed, 187 additions and 0 deletions

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 47 additions & 0 deletions
@@ -12931,6 +12931,53 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
   let hasFolder = 1;
 }
 
+def Torch_AtenHeavisideOp : Torch_Op<"aten.heaviside", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::heaviside : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeavisideOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeavisideOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenHeaviside_Op : Torch_Op<"aten.heaviside_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::heaviside_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeaviside_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeaviside_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
     AllowsTypeRefinement,
     HasValueSemantics,

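For reference, the ATen semantics these generated op definitions register (a quick eager-mode check; exact output assumes a recent PyTorch):

```python
import torch

x = torch.tensor([-1.5, 0.0, 2.0])
values = torch.tensor([0.5])
# heaviside(x, values): 0 where x < 0, `values` where x == 0, 1 where x > 0.
print(torch.heaviside(x, values))  # tensor([0.0000, 0.5000, 1.0000])
```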
lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 12 additions & 0 deletions
@@ -9650,6 +9650,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.heaviside\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -15122,6 +15126,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.heaviside\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

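The new shape function simply broadcasts the two operand shapes. A quick eager-mode illustration of that rule (a sketch; the output values themselves are secondary here):

```python
import torch

x = torch.tensor([[-2.0], [0.0], [3.0]])  # shape (3, 1)
v = torch.tensor([0.5, 7.0])              # shape (2,)
out = torch.heaviside(x, v)
# The result shape is the broadcast of the operand shapes, matching
# __torch_mlir_shape_fn.aten.heaviside above.
print(out.shape)  # torch.Size([3, 2])
```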
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 71 additions & 0 deletions
@@ -10971,6 +10971,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
 };
 } // namespace
 
+namespace {
+// Decomposes aten.heaviside op into
+// using aten.eq, aten.lt, aten.logical_or, aten.where
+// Heaviside(x, y) returns:
+//   0 if x < 0
+//   y if x == 0
+//   1 if x > 0
+class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHeavisideOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.getSelf();
+    auto value = op.getValues();
+    auto loc = op.getLoc();
+    auto inputTy = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "input must have dtype and size.");
+
+    auto valueTy = dyn_cast<BaseTensorType>(value.getType());
+    if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
+    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
+                          broadcastShapeValue);
+
+    auto broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
+    auto boolBroadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
+    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+    auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, input, indexBroadcastShapeTorchList);
+    auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, value, indexBroadcastShapeTorchList);
+
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
+                                                   resultTy.getDtype());
+    Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
+                                                  resultTy.getDtype());
+    // Compute mask: input == 0
+    auto inputEqZero = rewriter
+                           .create<AtenEqScalarOp>(loc, boolBroadcastType,
+                                                   inputBroadcasted, zero)
+                           ->getResult(0);
+    // Compute mask: input < 0
+    auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
+                                                       inputBroadcasted, zero);
+    // Compute mask: isnan(input)
+    auto isNan =
+        rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
+    // Combine: input < 0 || isnan(input)
+    auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
+        loc, boolBroadcastType, inputLtZero, isNan);
+    // Select 0 if input < 0 or input is nan, else 1
+    auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
+        loc, resultTy, inputNegativeOrNan, zero, one);
+    // Final result: if input == 0, take from valueBroadcasted, else take from
+    // zerosOrOnes
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
+                                                 valueBroadcasted, zerosOrOnes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Unconditionally decompose `torch.type_as` into `prim.dtype` +
 // `torch.to.dtype`.
@@ -12143,6 +12213,7 @@ class DecomposeComplexOpsPass
         DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);

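In PyTorch terms, the decomposition above computes roughly the following (an illustrative sketch; the helper name is made up for this example and is not part of the commit):

```python
import torch

def heaviside_decomposed(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # aten.broadcast_to on both operands, mirroring computeBroadcastShape.
    x_b, v_b = torch.broadcast_tensors(x, values)
    # aten.lt.Scalar + aten.isnan + aten.logical_or: negative or NaN inputs map to 0.
    neg_or_nan = torch.logical_or(x_b < 0, torch.isnan(x_b))
    # aten.where.Scalar: 0 where negative/NaN, 1 everywhere else.
    zeros_or_ones = torch.where(neg_or_nan, torch.zeros_like(x_b), torch.ones_like(x_b))
    # aten.where.self: take `values` where the input is exactly 0.
    return torch.where(x_b == 0, v_b, zeros_or_ones)

x = torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf])
v = torch.tensor([5.0])
assert torch.equal(heaviside_decomposed(x, v), torch.heaviside(x, v))
```

The explicit `isnan` branch matters: without it, `where(x < 0, 0, 1)` would map NaN inputs to 1, whereas eager `torch.heaviside` maps them to 0.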
lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -458,6 +458,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenSquareOp>();
   target.addIllegalOp<AtenVarOp>();
   target.addIllegalOp<AtenStdOp>();
+  target.addIllegalOp<AtenHeavisideOp>();
   target.addIllegalOp<Aten_UnsafeViewOp>();
   target.addIllegalOp<Aten_ReshapeAliasOp>();
   target.addIllegalOp<AtenBernoulliOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -1249,6 +1249,7 @@
     "ElementwiseToDtypeI64ToI8Module_basic",
     "ElementwiseToDtypeIdentityModule_basic",
     "ElementwiseUnaryModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "EmptyLikeMemoryFormatModule_basic",
     "EmptyLikeModule_defaultDtype",
     "EmptyLikeModule_falsePinMemory",
@@ -1849,6 +1850,7 @@
     "ElementwiseFracModule_basic",
     "ElementwiseLdexpModule_basic",
     "ElementwiseSignbitIntModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "Exp2StaticIntModule_basic",
     "MaxPool1dEmptyStrideStaticModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
@@ -2955,6 +2957,8 @@
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
     "HardtanhBackward_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",
     "HstackBasicIntFloatModule_basic",
@@ -3946,6 +3950,8 @@
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",
     "ElementwiseRreluWithNoiseTrainModule_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "RreluWithNoiseBackwardEvalModule_basic",
     "RreluWithNoiseBackwardEvalStaticModule_basic",
     "RreluWithNoiseBackwardTrainModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 0 deletions
@@ -1757,6 +1757,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
 def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(condition, other)
 
+def aten〇heaviside〡shape(self: List[int], values: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, values)
+
 def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -5031,6 +5034,14 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
     dtypes = [get_dtype_of_scalar(self), other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+def aten〇heaviside〡dtype(self_rank_dtype: Tuple[int, int], values_rank_dtype: Tuple[int, int]) -> int:
+    self_rank,self_dtype = self_rank_dtype
+    values_rank,values_dtype = values_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, values_rank]
+    dtypes = [self_dtype, values_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:

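The dtype function above defers to torch-mlir's `promote_dtypes` helper, i.e. standard tensor-tensor type promotion over the two operands. A rough eager-mode analogue (illustrative only):

```python
import torch

# Mirrors the rank/dtype promotion the new dtype function performs.
print(torch.result_type(torch.zeros(3, dtype=torch.int32),
                        torch.zeros(3, dtype=torch.float64)))  # torch.float64
```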
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -951,6 +951,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
     )
+    emit_with_mutating_variants("aten::heaviside : (Tensor, Tensor) -> (Tensor)")
     emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
     emit(
         "aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py

Lines changed: 38 additions & 0 deletions
@@ -298,6 +298,44 @@ def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseHeavisideModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5], torch.float32, True), ([1], torch.float32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideModule())
+def ElementwiseHeavisideModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf]), torch.tensor([5.0])
+    )
+
+
+class ElementwiseHeavisideIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int32, True), ([-1], torch.int32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
+def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(5, 1, low=-100, high=1000).to(torch.int32),
+        tu.randint(1, low=-100, high=1000).to(torch.int32),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseLtIntScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

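The second test exercises integer operands. For context, eager PyTorch accepts matching integer dtypes here as well (a small check, assuming a recent PyTorch):

```python
import torch

x = torch.tensor([-3, 0, 7], dtype=torch.int32)
v = torch.tensor([9], dtype=torch.int32)
print(torch.heaviside(x, v))  # tensor([0, 9, 1], dtype=torch.int32)
```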