Skip to content

Commit a21af81

Browse files
committed
addressed the comments
1 parent 1e4b934 commit a21af81

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11069,7 +11069,7 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
1106911069
} // namespace
1107011070

1107111071
namespace {
11072-
// Decomposes aten.heaviside op into
11072+
// Decomposed aten.heaviside op into
1107311073
// using aten.eq, aten.lt, aten.logical_or, aten.where
1107411074
// Heaviside(x, y) returns:
1107511075
// 0 if x < 0

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,16 +320,18 @@ def __init__(self):
320320
super().__init__()
321321

322322
@export
323-
@annotate_args([None, ([-1, -1], torch.int32, True), ([-1], torch.int32, True)])
323+
@annotate_args(
324+
[None, ([-1, -1, -1], torch.int64, True), ([-1, -1, -1, -1], torch.int64, True)]
325+
)
324326
def forward(self, x, values):
325327
return torch.heaviside(x, values)
326328

327329

328330
@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
329331
def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
330332
module.forward(
331-
tu.randint(5, 1, low=-100, high=1000).to(torch.int32),
332-
tu.randint(1, low=-100, high=1000).to(torch.int32),
333+
tu.randint(1, 2, 3, low=-100, high=1000),
334+
tu.randint(1, 1, 1, 1, low=-100, high=1000),
333335
)
334336

335337

0 commit comments

Comments
 (0)