Skip to content

Commit 4ad8d64

Browse files
YuchenJinslyubomirsky
authored andcommitted
[Unity][Pass] Canonicalize Bindings (#14079)
It may be useful for some passes to collapse chains of definitions, particularly after other compiler transformations that may reduce or simplify some expressions. This pass takes chains of definitions and replaces references to later definitions with references to the original one. It works by checking `LookupBinding` for each var use-site and replacing the var with its definition if the definition is another var. Additionally, `MatchCast` bindings where the LHS and the RHS are guaranteed to match at compile time are canonicalized into ordinary `VarBinding`s.

Example:
```python
y = x
z = y
w = z
o = w
p = o
```
will be replaced with
```python
y = x
z = x
w = x
o = x
p = x
```
Original PR: tlc-pack/relax#233

Co-authored-by: Steven S. Lyubomirsky <[email protected]>
1 parent e8a0c4d commit 4ad8d64

File tree

4 files changed

+382
-0
lines changed

4 files changed

+382
-0
lines changed

include/tvm/relax/transform.h

+9
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,15 @@ TVM_DLL Pass AttachGlobalSymbol();
127127
*/
128128
TVM_DLL Pass Normalize();
129129

130+
/*!
131+
* \brief Simplify a Relax module by folding var bindings and trivial MatchCast bindings.
132+
* May include other forms of expression simplification in the future.
133+
* Best used alongside constant folding and eliminating unused bindings.
134+
*
135+
* \return The Pass.
136+
*/
137+
TVM_DLL Pass CanonicalizeBindings();
138+
130139
/*!
131140
* \brief Bind params of function of the module to constant tensors.
132141
*

python/tvm/relax/transform/transform.py

+14
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,20 @@ def Normalize() -> tvm.ir.transform.Pass:
8080
return _ffi_api.Normalize() # type: ignore
8181

8282

83+
def CanonicalizeBindings() -> tvm.ir.transform.Pass:
    """Create the pass that canonicalizes variable definitions.

    For example, given ``y = x`` followed by ``z = y``, uses of ``y`` and ``z``
    are replaced with ``x``.

    The pass works best together with constant folding and the elimination of
    unused definitions.

    Returns
    -------
    ret: tvm.ir.transform.Pass
    """
    canonicalize_pass = _ffi_api.CanonicalizeBindings()  # type: ignore
    return canonicalize_pass
95+
96+
8397
def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
8498
"""Convert all reshape-like call_tir to VM reshape operator call.
8599
The VM reshape operator calls will be further lowered to a CreateView
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relax/transform/canonicalize_bindings.cc
22+
* \brief Pass for simplifying modules by folding var bindings and trivial MatchCast bindings.
23+
* May include other forms of simplification in the future.
24+
* Ideally should be used before constant folding and eliminating unused bindings.
25+
*/
26+
27+
#include <tvm/relax/expr.h>
28+
#include <tvm/relax/expr_functor.h>
29+
#include <tvm/relax/struct_info.h>
30+
#include <tvm/relax/transform.h>
31+
32+
namespace tvm {
33+
namespace relax {
34+
35+
class BindingCanonicalizer : public ExprMutator {
36+
public:
37+
BindingCanonicalizer() {}
38+
39+
Expr VisitExpr_(const VarNode* op) override {
40+
// remap first
41+
Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
42+
if (!CanCanonicalizeVar(v)) {
43+
return Downcast<Expr>(v);
44+
}
45+
// visit again in case we need to do a substitution in the value
46+
return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
47+
}
48+
49+
Expr VisitExpr_(const DataflowVarNode* op) override {
50+
Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
51+
if (!CanCanonicalizeVar(v)) {
52+
return Downcast<Expr>(v);
53+
}
54+
return ExprMutator::VisitExpr_(LookupBinding(v).as<DataflowVarNode>());
55+
}
56+
57+
void VisitBinding_(const VarBindingNode* binding) override {
58+
// Unlike default visitor, we do not permit the checked type to change
59+
// if the new value's checked type is different (this preserves user annotations)
60+
Expr new_value = this->VisitExpr(binding->value);
61+
Var new_var = this->VisitVarDef(binding->var);
62+
63+
if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) {
64+
this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
65+
return;
66+
}
67+
68+
this->builder_->EmitNormalized(VarBinding(new_var, new_value));
69+
}
70+
71+
void VisitBinding_(const MatchCastNode* binding) override {
72+
// If we have a trivial shape check (the shape_ of LHS and RHS is the same),
73+
// we can canonicalize to a var binding
74+
Expr new_value = this->VisitExpr(binding->value);
75+
76+
// if the LHS and RHS have the same struct info, we canonicalize to a var binding instead
77+
if (StructuralEqual()(binding->struct_info, GetStructInfo(new_value))) {
78+
builder_->EmitNormalized(VarBinding(binding->var, new_value));
79+
} else if (new_value.same_as(binding->value)) {
80+
builder_->EmitNormalized(GetRef<MatchCast>(binding));
81+
} else {
82+
builder_->EmitNormalized(MatchCast(binding->var, new_value, binding->struct_info));
83+
}
84+
}
85+
86+
private:
87+
bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2,
88+
std::function<bool(const ObjectRef&, const ObjectRef&)> check_eq) {
89+
// annotations differ if one is present but not the other
90+
// or they're both present and they differ
91+
bool both_present = obj1.defined() && obj2.defined();
92+
bool neither_present = !obj1.defined() && !obj2.defined();
93+
return !(both_present || neither_present) || (both_present && !check_eq(obj1, obj2));
94+
}
95+
96+
bool CanCanonicalizeVar(Var v) {
97+
Optional<Expr> value = LookupBinding(v);
98+
// can replace only if the value is also a var
99+
if (!value || !value.as<VarNode>()) {
100+
return false;
101+
}
102+
Var parent_var = Downcast<Var>(value);
103+
104+
// Cases when we conservatively do not unify:
105+
// 1. checked_type_ or shape_ of the child differs from that of the parent
106+
// In this case, we could be overriding user annotations.
107+
// 2. If the child is a Var and the parent is a DataflowVar.
108+
// That could result in a DataflowVar leaving the current DataflowBlock.
109+
bool annotations_differ = AnnotationsDiffer(v->struct_info_, parent_var->struct_info_,
110+
[&](const ObjectRef& lhs, const ObjectRef& rhs) {
111+
return tvm::StructuralEqual()(lhs, rhs);
112+
});
113+
bool var_to_dataflow = (!v.as<DataflowVarNode>() && parent_var.as<DataflowVarNode>());
114+
return !annotations_differ && !var_to_dataflow;
115+
}
116+
};
117+
118+
/*!
 * \brief Canonicalize the bindings of an expression.
 * \param e The expression to rewrite.
 * \return The result of visiting `e` with a fresh BindingCanonicalizer.
 */
Expr CanonicalizeBindings(const Expr& e) {
  BindingCanonicalizer canonicalizer;
  return canonicalizer.VisitExpr(e);
}
119+
120+
namespace transform {
121+
122+
Pass CanonicalizeBindings() {
123+
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
124+
[=](Function f, IRModule m, PassContext pc) {
125+
return Downcast<Function>(CanonicalizeBindings(f));
126+
};
127+
return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {});
128+
}
129+
130+
TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings);
131+
132+
} // namespace transform
133+
134+
} // namespace relax
135+
} // namespace tvm
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import tvm
19+
import tvm.script
20+
import tvm.testing
21+
import pytest
22+
from tvm import relax
23+
from tvm.ir.base import assert_structural_equal
24+
from tvm.script import relax as R, tir as T
25+
26+
27+
def test_simple_assignments():
    """A chain of var-to-var bindings is rewritten so each binding refers to the original var."""

    @tvm.script.ir_module
    class TestChainAssignments:
        @R.function
        def main(x: R.Tensor):
            y = x
            z = y
            q = z
            p = q
            o = p
            return o

    # The bindings that become unused are kept here;
    # a separate pass can eliminate them.
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor):
            y = x
            z = x
            q = x
            p = x
            o = x
            return x

    canonicalize = relax.transform.CanonicalizeBindings()
    after = canonicalize(TestChainAssignments)
    assert_structural_equal(after, Expected)
54+
55+
56+
def test_dataflow_block():
    """Bindings inside a dataflow block are canonicalized; the block output var is still used."""

    @tvm.script.ir_module
    class TestDataflowAssignments:
        @R.function
        def main(x: R.Tensor):
            with R.dataflow():
                y = R.const(1)
                z = y
                o = z
                p = o
                m = p
                n = m
                R.output(n)
            return n

    # The bindings that become unused are kept here;
    # a separate pass can eliminate them.
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor):
            with R.dataflow():
                y = R.const(1)
                z = y
                o = y
                p = y
                m = y
                # we can't get rid of n because it leaves the block
                n = y
                R.output(n)
            return n

    canonicalize = relax.transform.CanonicalizeBindings()
    after = canonicalize(TestDataflowAssignments)
    assert_structural_equal(after, Expected)
90+
91+
92+
def test_ops():
    """Aliased vars used as operator arguments are replaced with the vars they alias."""

    @tvm.script.ir_module
    class TestOps:
        @R.function
        def main(x: R.Tensor, y: R.Tensor):
            w = y
            q = x
            z = R.add(w, q)
            return R.add(q, z)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor, y: R.Tensor):
            w = y
            q = x
            z = R.add(y, x)
            return R.add(x, z)

    canonicalize = relax.transform.CanonicalizeBindings()
    after = canonicalize(TestOps)
    assert_structural_equal(after, Expected)
113+
114+
115+
@pytest.mark.xfail(reason="The lhs and rhs of an assignment should have the same struct info.")
def test_casting():
    """A binding annotated with a different struct info than its value (marked xfail)."""

    @tvm.script.ir_module
    class TestCasting:
        @R.function
        def main(x: R.Tensor) -> R.Object:
            y = x
            # z will be treated as object type even though it's a tensor
            z: R.Object = y
            return z

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor) -> R.Object:
            y = x
            # Cannot unify because the cast indicates user intent
            z: R.Object = x
            return z

    canonicalize = relax.transform.CanonicalizeBindings()
    after = canonicalize(TestCasting)
    assert_structural_equal(after, Expected)
137+
138+
139+
def test_match_cast():
    """A match_cast that changes the struct info is kept, and later aliases refer to its var."""

    @tvm.script.ir_module
    class TestMatchCast:
        @R.function
        def main(x: R.Tensor):
            q = x
            m, n = T.var("int64"), T.var("int64")
            z = R.match_cast(q, R.Tensor((m, n)))
            w = z
            return w

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor):
            q = x
            # can't get rid of z because its struct info is different from x's
            m, n = T.var("int64"), T.var("int64")
            z = R.match_cast(x, R.Tensor((m, n)))
            w = z
            return z

    canonicalize = relax.transform.CanonicalizeBindings()
    after = canonicalize(TestMatchCast)
    assert_structural_equal(after, Expected)
163+
164+
165+
def test_same_shape():
    """A match_cast to the struct info the value already has becomes a plain var binding."""

    @tvm.script.ir_module
    class TestSameShape:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")):
            m, n = T.var("int64"), T.var("int64")
            y = x
            # trivial check
            z = R.match_cast(x, R.Tensor((m, n), "float32"))
            w = z
            q = R.add(w, y)
            return R.add(q, w)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", "n"), "float32")):
            m, n = T.var("int64"), T.var("int64")
            y = x
            # canonicalized into a var binding
            z = x
            w = x
            q = R.add(x, x)
            return R.add(q, x)

    canonicalize = relax.transform.CanonicalizeBindings()
    after = canonicalize(TestSameShape)
    assert_structural_equal(after, Expected)
192+
193+
194+
def test_change_shape():
    """A match_cast that introduces new shape vars is kept; its aliases are still canonicalized."""

    @tvm.script.ir_module
    class TestChangeShape:
        @R.function
        def main(x: R.Tensor(("m", "n"))):
            y = x
            # not trivial: introduces new shape vars
            o, p = T.var("int64"), T.var("int64")
            z = R.match_cast(x, R.Tensor((o, p)))
            w = z
            q = R.add(w, y)
            return R.add(q, w)

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", "n"))):
            y = x
            o, p = T.var("int64"), T.var("int64")
            z = R.match_cast(x, R.Tensor((o, p)))
            w = z
            # q's operands are canonicalized: w becomes z and y becomes x
            q = R.add(z, x)
            return R.add(q, z)

    canonicalize = relax.transform.CanonicalizeBindings()
    after = canonicalize(TestChangeShape)
    assert_structural_equal(after, Expected)
221+
222+
223+
# Script entry point: when this file is executed directly, hand control to tvm.testing.main().
if __name__ == "__main__":
224+
tvm.testing.main()

0 commit comments

Comments
 (0)