[TORCH] Add support for aten.heaviside Op #4220


Open · wants to merge 2 commits into main

Conversation


@sharavak sharavak commented Jun 2, 2025

  • Decomposed the heaviside op into elementwise Aten ops.
  • Added test cases in the e2e suite.

This implementation addresses and closes #4211


sharavak commented Jun 2, 2025

@stellaraccident @vivekkhandelwal1 @penguin-wwy @zjgarvey @AmosLewis, I’d be grateful if any of you could take a look at this PR. Your feedback would be greatly appreciated!

Comment on lines +10997 to +11109
  SmallVector<int64_t> broadcastShape;
  SmallVector<Value> broadcastShapeValue;
  computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
                        broadcastShapeValue);

  auto broadcastType = ValueTensorType::get(
      op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
  auto boolBroadcastType = ValueTensorType::get(
      op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
  Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
      broadcastShapeValue);
  auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
      loc, broadcastType, input, indexBroadcastShapeTorchList);
  auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
      loc, broadcastType, value, indexBroadcastShapeTorchList);
Collaborator:
I think this is not needed. Since you are decomposing this op into elementwise ops, the broadcasting part will be handled during Torch->Linalg lowering.

sharavak (Author) Jun 10, 2025:

@vivekkhandelwal1 You're right.

But I ran into an issue in a specific case: when the input shape is [1, 2, 3] and the value shape is [1, 1, 1, 1], the broadcasted result shape becomes [1, 1, 2, 3].

Without explicitly broadcasting the inputs, some intermediate ops (like aten.eq.scalar or aten.isnan) end up producing tensors of shape [1, 2, 3], which causes this error:

    'tensor.cast' op operand type 'tensor<1x2x3xi1>' and result type 'tensor<1x1x2x3xi1>' are cast incompatible

To avoid this mismatch, I added explicit broadcasting so that all intermediate results match the final shape.


sharavak commented Jun 10, 2025

@vivekkhandelwal1 Thanks a lot for the feedback. I’ve updated the code.

Successfully merging this pull request may close these issues: [TORCH] Add support for aten.heaviside