19
19
#include < tvm/relax/expr.h>
20
20
21
21
namespace tvm {
22
+
23
+ RelayExpr RelayExprNode::shape () const {
24
+ if (this ->shape_ .defined ()) {
25
+ return Downcast<RelayExpr>(this ->shape_ );
26
+ }
27
+ static const Op& op = Op::Get (" relax.shape_of" );
28
+ RelayExpr self = GetRef<RelayExpr>(this );
29
+ return relay::Call (op, {self}, {}, {});
30
+ }
31
+
32
+ TVM_REGISTER_GLOBAL (" ir.RelayExprShape" )
33
+ .set_body_typed([](RelayExpr expr) {
34
+ return expr->shape ();
35
+ });
36
+
22
37
namespace relax {
23
38
24
39
using tvm::runtime::Optional;
25
40
26
-
27
41
TVM_REGISTER_NODE_TYPE (ShapeExprNode);
28
42
29
43
ShapeExpr::ShapeExpr (Array<PrimExpr> values) {
@@ -41,7 +55,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr")
41
55
TVM_REGISTER_NODE_TYPE (VarNode);
42
56
43
57
Var::Var (Id vid,
44
- Optional<Array<PrimExpr> > shape_annotation,
58
+ Optional<Expr > shape_annotation,
45
59
Optional<Type> type_annotation,
46
60
Span span) {
47
61
ObjectPtr<VarNode> n = make_object<VarNode>();
@@ -54,7 +68,7 @@ Var::Var(Id vid,
54
68
55
69
TVM_REGISTER_GLOBAL (" relax.Var" )
56
70
.set_body_typed([](String name_hint,
57
- Optional<Array<PrimExpr> > shape_annotation,
71
+ Optional<Expr > shape_annotation,
58
72
Optional<Type> type_annotation) {
59
73
return Var (name_hint, shape_annotation, type_annotation);
60
74
});
@@ -64,7 +78,7 @@ TVM_REGISTER_NODE_TYPE(DataflowVarNode);
64
78
65
79
TVM_REGISTER_GLOBAL (" relax.DataflowVar" )
66
80
.set_body_typed([](String name_hint,
67
- Optional<Array<PrimExpr> > shape_annotation,
81
+ Optional<Expr > shape_annotation,
68
82
Optional<Type> type_annotation) {
69
83
return DataflowVar (name_hint, shape_annotation, type_annotation);
70
84
});
0 commit comments