Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Update AST and Shape() implementation #5

Merged
merged 4 commits into from
Aug 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,13 @@ class RelayExprNode : public BaseExprNode {
mutable Type checked_type_ = Type(nullptr);

/*!
* \brief Stores the result of static shape analysis.
* \brief Stores the result of static shape analysis. It must be a RelayExpr
* and ObjectRef is used here to avoid cyclic typing.
*
* \note The value will be optional if a static shape can not be inferred.
* use .shape() instead to acesss an always defined shape expression.
*/
Optional<Array<PrimExpr>> shape_ = Optional<Array<PrimExpr>>();
mutable Optional<ObjectRef> shape_ = Optional<ObjectRef>();

/*!
* \return The checked_type
Expand All @@ -168,7 +169,7 @@ class RelayExprNode : public BaseExprNode {
*
* Only valid when the expression's type is a Tensor.
*/
inline RelayExpr shape() const;
RelayExpr shape() const;

/*!
* \brief Check if the inferred(checked) type of the Expr
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
namespace tvm {
namespace relax {

using Expr = RelayExpr;
using ExprNode = RelayExprNode;
using relay::Id;
using relay::Call;
using relay::Tuple;
using relay::TupleGetItem;
using ExprNode = RelayExprNode;
using Expr = RelayExpr;

/*! \brief A shape expression which allows users to construct a shape containing PrimExpr.
*/
Expand Down Expand Up @@ -121,13 +121,13 @@ class VarNode : public ExprNode {
class Var : public Expr {
public:
TVM_DLL Var(String name_hint,
runtime::Optional<Array<PrimExpr>> shape_annotation,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span())
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL Var(Id vid,
runtime::Optional<Array<PrimExpr>> shape_annotation,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ def checked_type(self):
raise ValueError("The type checker has not populated" " the checked_type for this node")
return ret

@property
def shape(self):
"""Get the shape of tvm.relay.Expr.

Returns
-------
shape : tvm.ir.RelayExpr
The expression that represents the shape.
"""
return _ffi_api.RelayExprShape(self)



@tvm._ffi.register_object("GlobalVar")
class GlobalVar(RelayExpr):
Expand Down
18 changes: 17 additions & 1 deletion python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,31 @@ class ShapeExpr(Expr):
def __init__(self, values: List[PrimExpr]) -> None:
self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values)

def __getitem__(self, index):
if index >= len(self):
raise IndexError("Tuple index out of range")
return self.values[index]

def __len__(self):
return len(self.values)

def make_shape(shape: List[PrimExpr]) -> ShapeExpr:
if isinstance(shape, (list, tuple)):
return ShapeExpr(shape)
else:
raise ValueError


@tvm._ffi.register_object("relax.expr.Var")
class Var(Expr):
id: Id
type_annotation: Optional[Type]

def __init__(self, name_hint: str,
shape_annotation: Optional[List[Type]] = None,
shape_annotation: Optional[Expr] = None,
type_annotation: Optional[Type] = None) -> None:
if shape_annotation is not None:
shape_annotation = make_shape(shape_annotation)
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint,
shape_annotation,
type_annotation)
Expand Down
22 changes: 18 additions & 4 deletions src/relax/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,25 @@
#include <tvm/relax/expr.h>

namespace tvm {

RelayExpr RelayExprNode::shape() const {
if (this->shape_.defined()) {
return Downcast<RelayExpr>(this->shape_);
}
static const Op& op = Op::Get("relax.shape_of");
RelayExpr self = GetRef<RelayExpr>(this);
return relay::Call(op, {self}, {}, {});
}

TVM_REGISTER_GLOBAL("ir.RelayExprShape")
.set_body_typed([](RelayExpr expr) {
return expr->shape();
});

namespace relax {

using tvm::runtime::Optional;


TVM_REGISTER_NODE_TYPE(ShapeExprNode);

ShapeExpr::ShapeExpr(Array<PrimExpr> values) {
Expand All @@ -41,7 +55,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr")
TVM_REGISTER_NODE_TYPE(VarNode);

Var::Var(Id vid,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Expr> shape_annotation,
Optional<Type> type_annotation,
Span span) {
ObjectPtr<VarNode> n = make_object<VarNode>();
Expand All @@ -54,7 +68,7 @@ Var::Var(Id vid,

TVM_REGISTER_GLOBAL("relax.Var")
.set_body_typed([](String name_hint,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Expr> shape_annotation,
Optional<Type> type_annotation) {
return Var(name_hint, shape_annotation, type_annotation);
});
Expand All @@ -64,7 +78,7 @@ TVM_REGISTER_NODE_TYPE(DataflowVarNode);

TVM_REGISTER_GLOBAL("relax.DataflowVar")
.set_body_typed([](String name_hint,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Expr> shape_annotation,
Optional<Type> type_annotation) {
return DataflowVar(name_hint, shape_annotation, type_annotation);
});
Expand Down
27 changes: 21 additions & 6 deletions src/relax/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,34 @@
namespace tvm {
namespace relax {

// call_dps

RELAY_REGISTER_OP("relax.call_dps")
.set_num_inputs(3)
.add_argument("shape", "ShapeExpr", "The output shape.")
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.");

Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) {
static const Op& op = Op::Get("call_dps");
static const Op& op = Op::Get("relax.call_dps");
return Call(op, {shape, func, args}, {}, {});
}

TVM_REGISTER_GLOBAL("relax.op.call_dps")
.set_body_typed(MakeCallDPS);

RELAY_REGISTER_OP("call_dps")
.set_num_inputs(3)
.add_argument("shape", "ShapeExpr", "The output shape.")
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.");
// shape_of

RELAY_REGISTER_OP("relax.shape_of")
.set_num_inputs(1)
.add_argument("input", "Expr", "The input expression");

Expr MakeShapeOf(Expr expr) {
static const Op& op = Op::Get("relax.shape_of");
return Call(op, {expr}, {}, {});
}

TVM_REGISTER_GLOBAL("relax.op.shape_of")
.set_body_typed(MakeShapeOf);
} // namespace relax
} // namespace tvm
14 changes: 14 additions & 0 deletions tests/python/relax/test_ast.py → tests/python/relax/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ def test_func():
assert func.name.name_hint == "func"


def test_shape_of():
v0 = rx.Var("v0")
s0 = v0.shape
assert isinstance(s0, tvm.relay.Call)
assert s0.op.name == "relax.shape_of"

shape_anno = [96, 54]
v1 = rx.Var("v1", shape_anno)
s1 = v1.shape
for x, y in zip(shape_anno, s1):
assert x == y


if __name__ == "__main__":
test_var()
test_dataflow_var()
Expand All @@ -123,3 +136,4 @@ def test_func():
test_seq_expr()
test_shape_expr()
test_func()
test_shape_of()