Skip to content

Commit 7b10589

Browse files
ZihengJiang authored and yongwww committed
Update AST and Shape() implementation (tlc-pack#5)
* Update AST. * ShapeOf. * ShapeOf. * Address comment.
1 parent 81c1004 commit 7b10589

File tree

7 files changed

+90
-18
lines changed

7 files changed

+90
-18
lines changed

include/tvm/ir/expr.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,13 @@ class RelayExprNode : public BaseExprNode {
154154
mutable Type checked_type_ = Type(nullptr);
155155

156156
/*!
157-
* \brief Stores the result of static shape analysis.
157+
* \brief Stores the result of static shape analysis. It must be a RelayExpr
158+
* and ObjectRef is used here to avoid cyclic typing.
158159
*
159160
* \note The value will be optional if a static shape can not be inferred.
160161
* use .shape() instead to acesss an always defined shape expression.
161162
*/
162-
Optional<Array<PrimExpr>> shape_ = Optional<Array<PrimExpr>>();
163+
mutable Optional<ObjectRef> shape_ = Optional<ObjectRef>();
163164

164165
/*!
165166
* \return The checked_type
@@ -171,7 +172,7 @@ class RelayExprNode : public BaseExprNode {
171172
*
172173
* Only valid when the expression's type is a Tensor.
173174
*/
174-
inline RelayExpr shape() const;
175+
RelayExpr shape() const;
175176

176177
/*!
177178
* \brief Check if the inferred(checked) type of the Expr

include/tvm/relax/expr.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
namespace tvm {
3232
namespace relax {
3333

34+
using Expr = RelayExpr;
35+
using ExprNode = RelayExprNode;
3436
using relay::Id;
3537
using relay::Call;
3638
using relay::Tuple;
3739
using relay::TupleGetItem;
38-
using ExprNode = RelayExprNode;
39-
using Expr = RelayExpr;
4040

4141
/*! \brief A shape expression which allows users to construct a shape containing PrimExpr.
4242
*/
@@ -121,13 +121,13 @@ class VarNode : public ExprNode {
121121
class Var : public Expr {
122122
public:
123123
TVM_DLL Var(String name_hint,
124-
runtime::Optional<Array<PrimExpr>> shape_annotation,
124+
runtime::Optional<Expr> shape_annotation,
125125
runtime::Optional<Type> type_annotation,
126126
Span span = Span())
127127
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}
128128

129129
TVM_DLL Var(Id vid,
130-
runtime::Optional<Array<PrimExpr>> shape_annotation,
130+
runtime::Optional<Expr> shape_annotation,
131131
runtime::Optional<Type> type_annotation,
132132
Span span = Span());
133133
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);

python/tvm/ir/expr.py

+12
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,18 @@ def checked_type(self):
5050
raise ValueError("The type checker has not populated" " the checked_type for this node")
5151
return ret
5252

53+
@property
54+
def shape(self):
55+
"""Get the shape of tvm.relay.Expr.
56+
57+
Returns
58+
-------
59+
shape : tvm.ir.RelayExpr
60+
The expression that represents the shape.
61+
"""
62+
return _ffi_api.RelayExprShape(self)
63+
64+
5365

5466
@tvm._ffi.register_object("GlobalVar")
5567
class GlobalVar(RelayExpr):

python/tvm/relax/expr.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,31 @@ class ShapeExpr(Expr):
3737
def __init__(self, values: List[PrimExpr]) -> None:
3838
self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values)
3939

40+
def __getitem__(self, index):
41+
if index >= len(self):
42+
raise IndexError("Tuple index out of range")
43+
return self.values[index]
44+
45+
def __len__(self):
46+
return len(self.values)
47+
48+
def make_shape(shape: List[PrimExpr]) -> ShapeExpr:
49+
if isinstance(shape, (list, tuple)):
50+
return ShapeExpr(shape)
51+
else:
52+
raise ValueError
53+
4054

4155
@tvm._ffi.register_object("relax.expr.Var")
4256
class Var(Expr):
4357
id: Id
4458
type_annotation: Optional[Type]
4559

4660
def __init__(self, name_hint: str,
47-
shape_annotation: Optional[List[Type]] = None,
61+
shape_annotation: Optional[Expr] = None,
4862
type_annotation: Optional[Type] = None) -> None:
63+
if shape_annotation is not None:
64+
shape_annotation = make_shape(shape_annotation)
4965
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint,
5066
shape_annotation,
5167
type_annotation)

src/relax/expr.cc

+18-4
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,25 @@
1919
#include <tvm/relax/expr.h>
2020

2121
namespace tvm {
22+
23+
RelayExpr RelayExprNode::shape() const {
24+
if (this->shape_.defined()) {
25+
return Downcast<RelayExpr>(this->shape_);
26+
}
27+
static const Op& op = Op::Get("relax.shape_of");
28+
RelayExpr self = GetRef<RelayExpr>(this);
29+
return relay::Call(op, {self}, {}, {});
30+
}
31+
32+
TVM_REGISTER_GLOBAL("ir.RelayExprShape")
33+
.set_body_typed([](RelayExpr expr) {
34+
return expr->shape();
35+
});
36+
2237
namespace relax {
2338

2439
using tvm::runtime::Optional;
2540

26-
2741
TVM_REGISTER_NODE_TYPE(ShapeExprNode);
2842

2943
ShapeExpr::ShapeExpr(Array<PrimExpr> values) {
@@ -41,7 +55,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr")
4155
TVM_REGISTER_NODE_TYPE(VarNode);
4256

4357
Var::Var(Id vid,
44-
Optional<Array<PrimExpr>> shape_annotation,
58+
Optional<Expr> shape_annotation,
4559
Optional<Type> type_annotation,
4660
Span span) {
4761
ObjectPtr<VarNode> n = make_object<VarNode>();
@@ -54,7 +68,7 @@ Var::Var(Id vid,
5468

5569
TVM_REGISTER_GLOBAL("relax.Var")
5670
.set_body_typed([](String name_hint,
57-
Optional<Array<PrimExpr>> shape_annotation,
71+
Optional<Expr> shape_annotation,
5872
Optional<Type> type_annotation) {
5973
return Var(name_hint, shape_annotation, type_annotation);
6074
});
@@ -64,7 +78,7 @@ TVM_REGISTER_NODE_TYPE(DataflowVarNode);
6478

6579
TVM_REGISTER_GLOBAL("relax.DataflowVar")
6680
.set_body_typed([](String name_hint,
67-
Optional<Array<PrimExpr>> shape_annotation,
81+
Optional<Expr> shape_annotation,
6882
Optional<Type> type_annotation) {
6983
return DataflowVar(name_hint, shape_annotation, type_annotation);
7084
});

src/relax/op.cc

+21-6
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,34 @@
2222
namespace tvm {
2323
namespace relax {
2424

25+
// call_dps
26+
27+
RELAY_REGISTER_OP("relax.call_dps")
28+
.set_num_inputs(3)
29+
.add_argument("shape", "ShapeExpr", "The output shape.")
30+
.add_argument("func", "Expr", "The destination-passing-style function.")
31+
.add_argument("args", "Tuple", "The input arguments.");
32+
2533
Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) {
26-
static const Op& op = Op::Get("call_dps");
34+
static const Op& op = Op::Get("relax.call_dps");
2735
return Call(op, {shape, func, args}, {}, {});
2836
}
2937

3038
TVM_REGISTER_GLOBAL("relax.op.call_dps")
3139
.set_body_typed(MakeCallDPS);
3240

33-
RELAY_REGISTER_OP("call_dps")
34-
.set_num_inputs(3)
35-
.add_argument("shape", "ShapeExpr", "The output shape.")
36-
.add_argument("func", "Expr", "The destination-passing-style function.")
37-
.add_argument("args", "Tuple", "The input arguments.");
41+
// shape_of
42+
43+
RELAY_REGISTER_OP("relax.shape_of")
44+
.set_num_inputs(1)
45+
.add_argument("input", "Expr", "The input expression");
46+
47+
Expr MakeShapeOf(Expr expr) {
48+
static const Op& op = Op::Get("relax.shape_of");
49+
return Call(op, {expr}, {}, {});
50+
}
3851

52+
TVM_REGISTER_GLOBAL("relax.op.shape_of")
53+
.set_body_typed(MakeShapeOf);
3954
} // namespace relax
4055
} // namespace tvm

tests/python/relax/test_ast.py renamed to tests/python/relax/test_expr.py

+14
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,19 @@ def test_func():
113113
assert func.name.name_hint == "func"
114114

115115

116+
def test_shape_of():
117+
v0 = rx.Var("v0")
118+
s0 = v0.shape
119+
assert isinstance(s0, tvm.relay.Call)
120+
assert s0.op.name == "relax.shape_of"
121+
122+
shape_anno = [96, 54]
123+
v1 = rx.Var("v1", shape_anno)
124+
s1 = v1.shape
125+
for x, y in zip(shape_anno, s1):
126+
assert x == y
127+
128+
116129
if __name__ == "__main__":
117130
test_var()
118131
test_dataflow_var()
@@ -123,3 +136,4 @@ def test_func():
123136
test_seq_expr()
124137
test_shape_expr()
125138
test_func()
139+
test_shape_of()

0 commit comments

Comments
 (0)