Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

[Pass] Canonicalizing Bindings #233

Merged
merged 14 commits into from
Sep 8, 2022
9 changes: 9 additions & 0 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ TVM_DLL Pass ToNonDataflow();
*/
TVM_DLL Pass CallTIRRewrite();

/*!
* \brief Simplify a Relax module by folding var bindings and match shape nodes.
* May include other forms of expression simplification in the future.
* Best used alongside constant folding and eliminating unused bindings.
*
* \return The Pass.
*/
TVM_DLL Pass CanonicalizeBindings();

/*!
* \brief Transform Relax IR to normal form: transform AST to A-normal form, and fill the
* checked_type_ and shape_ of expressions.
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,20 @@ def Normalize() -> tvm.ir.transform.Pass:
return _ffi_api.Normalize()


def CanonicalizeBindings() -> tvm.ir.transform.Pass:
"""
Canonicalizes variable definitions
(e.g., if there is y = x and z = y, it replaces uses of y and z with x).

Best combined with constant folding and the elimination of unused definitions.

Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.CanonicalizeBindings()


def ResolveGlobals() -> tvm.ir.transform.Pass:
"""Resolve global variables using string equality. This ensures all GlobalVars in the IR refer
to the correct GlobalVar of the input IRModule. An error is reported if any GlobalVar cannot be
Expand Down
2 changes: 2 additions & 0 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ Var BlockBuilderNode::EmitMatchShape(const Expr& value, const Array<PrimExpr>& p

MatchShape match_shape = MatchShape(value, pattern, var);
cur_frame->bindings.push_back(match_shape);
binding_table_[var->vid] = value;
return var;
}

Expand All @@ -629,6 +630,7 @@ Var BlockBuilderNode::EmitMatchShape(const MatchShape& binding) {
<< "EmitMatchShape can only be used for local bindings in a dataflow block.";
ICHECK(cur_frame->is_dataflow || !binding->var.as<DataflowVarNode>())
<< "cannot emit dataflow vars outside a dataflow block: " << binding->var->name_hint();
binding_table_[binding->var->vid] = binding->value;
}
cur_frame->bindings.push_back(binding);
// TODO(@altanh, @yuchen): what value should we bind? Consider
Expand Down
182 changes: 182 additions & 0 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/transform/canonicalize.cc
* \brief Pass for simplifying modules by folding var bindings and match shape nodes.
* May include other forms of simplification in the future.
* Ideally should be used before constant folding and eliminating unused bindings.
*/

#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

namespace tvm {
namespace relax {

class BindingCanonicalizer : public ExprMutator {
public:
BindingCanonicalizer() {}

Expr VisitExpr_(const VarNode* op) override {
// remap first
Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
if (!CanCanonicalizeVar(v)) {
return Downcast<Expr>(v);
}
// visit again in case we need to do a substitution in the value
return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
}

Expr VisitExpr_(const DataflowVarNode* op) override {
Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
if (!CanCanonicalizeVar(v)) {
return Downcast<Expr>(v);
}
return ExprMutator::VisitExpr_(LookupBinding(v).as<DataflowVarNode>());
}

void VisitBinding_(const VarBindingNode* binding) override {
// Unlike default visitor, preserve the checked_type_
// We may need to change the shape field in case there are substitutions
// that need to be performed within the shape computation.
Expr new_value = this->VisitExpr(binding->value);
Var new_var = this->VisitVarDef(binding->var);

auto emit = [this](VarBinding b) {
if (this->builder_->CurrentBlockIsDataFlow() && !b->var.as<DataflowVarNode>()) {
this->builder_->EmitOutput(b);
} else {
this->builder_->Emit(b);
}
};

if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
emit(GetRef<VarBinding>(binding));
return;
}

// we don't look at the new value's shape or checked type; we only consider
// if there were any substitutions performed within the original var's shape_
Var temp = WithShapeAndType(new_var, new_var->shape_, new_var->checked_type_);
if (!temp.same_as(new_var)) {
new_var = temp;
this->var_remap_[binding->var->vid] = new_var;
}

// unlike default visitor, we do not permit the var's checked_type to change
emit(VarBinding(new_var, new_value));
Copy link
Contributor

@psrivas2 psrivas2 Aug 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is my understanding correct that this block of code is not needed if we allow type refinement. So z: Object = x can be refined to z: Tensor = x. So do we still need this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am under the impression that we should respect user annotations when they appear, which is why I've done it

Copy link
Contributor

@psrivas2 psrivas2 Sep 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I went back and checked the notes in August 16, 2022 meeting. Seems like this topic is under discussion. Would be great if can reach consensus on this soon.

}

void VisitBinding_(const MatchShapeNode* binding) override {
// for match shape, we need to be cleverer and allow the shape_ to change
// due to possible substitutions
Expr new_value = this->VisitExpr(binding->value);
Expr new_pattern = this->VisitExpr(ShapeExpr(binding->pattern));

Var new_var;
if (binding->var.defined()) {
Optional<Expr> new_shape;
if (new_value->checked_type_.defined() && new_value->checked_type_.as<DynTensorTypeNode>()) {
new_shape = new_pattern;
}
// visit var def visits the var's shape_ field and may perform variable substitutions,
// so we should use that shape_ if it's defined
new_var = this->VisitVarDef(binding->var);
if (new_var->shape_.defined()) {
new_shape = Downcast<Expr>(new_var->shape_);
}

// do not permit the type to change
Var temp = WithShapeAndType(new_var, new_shape, binding->var->checked_type_);
if (!temp.same_as(new_var)) {
new_var = temp;
this->var_remap_[binding->var->vid] = new_var;
}
}

// reemit old binding if nothing changes
if (new_value.same_as(binding->value) && new_pattern.same_as(binding->pattern)) {
if (!binding->var.defined() || (binding->var.defined() && new_var.same_as(binding->var))) {
builder_->EmitMatchShape(GetRef<MatchShape>(binding));
return;
}
}

builder_->EmitMatchShape(
MatchShape(new_value, Downcast<ShapeExpr>(new_pattern)->values, new_var));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this required? None of the tests check for this.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests do fail if you leave this out. This is because visiting shape_ can cause variables inside the shape_ field to change. This comes up in the test cases with relax.add, since the shape_ for that is a PackedFunc call that uses variables in the program.


private:
bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2,
std::function<bool(const ObjectRef&, const ObjectRef&)> check_eq) {
// annotations differ if one is present but not the other
// or they're both present and they differ
bool both_present = obj1.defined() && obj2.defined();
bool neither_present = !obj1.defined() && !obj2.defined();
return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2));
}

bool CanCanonicalizeVar(Var v) {
Optional<Expr> value = LookupBinding(v);
// can replace only if the value is also a var
if (!value || !value.as<VarNode>()) {
return false;
}
Var parent_var = Downcast<Var>(value);

// Cases when we conservatively do not unify:
// 1. checked_type_ or shape_ of the child differs from that of the parent
// In this case, we could be overriding user annotations.
// 2. If the child is a Var and the parent is a DataflowVar.
// That could result in a DataflowVar leaving the current DataflowBlock.
bool annotations_differ =
AnnotationsDiffer(v->shape_, parent_var->shape_,
[&](const ObjectRef& shape1, const ObjectRef& shape2) {
return builder_->CanProveShapeEqual(Downcast<Expr>(shape1),
Downcast<Expr>(shape2));
}) ||
AnnotationsDiffer(v->checked_type_, parent_var->checked_type_,
[&](const ObjectRef& type1, const ObjectRef& type2) {
return tvm::StructuralEqual()(type1, type2);
});
bool var_to_dataflow = (!v.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
return !annotations_differ && !var_to_dataflow;
}
};

Expr CanonicalizeBindings(const Expr& e) { return BindingCanonicalizer().VisitExpr(e); }

namespace transform {

Pass CanonicalizeBindings() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CanonicalizeBindings(f));
};
return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {});
}

TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings);

} // namespace transform

} // namespace relax
} // namespace tvm
Loading