Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit 1c73696

Browse files
authored
[TVMScript] Switch to the new parser (#276)
* [TVMScript] Support cross-function call for relax function This PR adds support for cross-function call for relax function, by declaring a function signature (i.e. an empty function that contains params and return type/shape but w/o body.) However, the PR meets the issue of block_builder shape deduction, which does not use function `ret_shape` to infer the shape of GlobalVar Calls.
1 parent a443b8d commit 1c73696

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1742
-1386
lines changed

apps/relax_examples/nn_module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252

5353
# get and print the IRmodule being built
5454
mod = builder.get()
55-
print(R.parser.astext(mod))
55+
mod.show()
5656

5757
# build the IRModule and create relax vm
5858
target = tvm.target.Target("llvm", host="llvm")

apps/relax_examples/resnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
relax_mod = relay_translator.from_relay(relay_mod["main"], target)
3434

3535
# print the ResNet IRmodule got translated
36-
print(R.parser.astext(relax_mod))
36+
relax_mod.show()
3737

3838
# build the IRModule and create relax vm
3939
ex = relax.vm.build(relax_mod, target)

include/tvm/script/ir_builder/ir/ir.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,12 @@ TVM_DLL IRModuleFrame IRModule();
4141
* \brief Declare a Function without given the specific function implementation.
4242
* \note It is usually used in cross-function call. And we can specify the function by `DefFunction`
4343
* \param func_name The function unique name.
44+
* \param func_signature A Function w/o body, which used to specify the function signature
45+
* (i.e. func params and func return type/shape).
4446
* \return The corresponding GlobalVar.
4547
*/
46-
TVM_DLL GlobalVar DeclFunction(const String& func_name);
48+
TVM_DLL GlobalVar DeclFunction(const String& func_name,
49+
const Optional<BaseFunc>& func_signature = NullOpt);
4750

4851
/*!
4952
* \brief Define the function which is declared before.

include/tvm/script/ir_builder/relax/frame.h

+16-4
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class SeqExprFrameNode : public RelaxFrameNode {
6767
TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode);
6868

6969
public:
70+
void EnterWithScope() override;
7071
void ExitWithScope() override;
7172
};
7273

@@ -94,6 +95,11 @@ class FunctionFrameNode : public SeqExprFrameNode {
9495
* If the `ret_type` is not None, check the deduced type is a base type of the given one.
9596
*/
9697
Optional<Type> ret_type;
98+
/*!
99+
* \brief The function return shape.
100+
* \sa ret_type
101+
*/
102+
Optional<tvm::relax::Expr> ret_shape;
97103
/*! \brief The function attributes. */
98104
Map<String, ObjectRef> attrs;
99105
/*! \brief The block builder to create Relax function. */
@@ -130,17 +136,23 @@ class BlockFrameNode : public RelaxFrameNode {
130136
/*! \brief The variables emitted in this block. */
131137
Array<tvm::relax::Var> emitted_vars;
132138
/*!
133-
* \brief (Only used for a dataflow block.) A boolean indicating if the dataflow block is ended of
134-
* construction. If it is true, any new binding trying to be emitted into this block will cause an
135-
* error.
139+
* \brief A boolean indicating if the dataflow block is ended of construction.
140+
* If it is true, any new binding trying to be emitted into this block will cause an error.
141+
* \note Only used for a dataflow block.
136142
*/
137143
bool block_ended;
144+
/*!
145+
* \brief The output vars of the dataflow block.
146+
* \note Only used for a dataflow block.
147+
*/
148+
Array<tvm::relax::Var> output_vars;
138149

139150
void VisitAttrs(tvm::AttrVisitor* v) {
140151
RelaxFrameNode::VisitAttrs(v);
141152
v->Visit("is_dataflow", &is_dataflow);
142153
v->Visit("emitted_vars", &emitted_vars);
143-
v->Visit("block_ended", &block_ended);
154+
v->Visit("output_vars", &output_vars);
155+
// `block_ended` is not visited.
144156
}
145157

146158
static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame";

include/tvm/script/ir_builder/relax/ir.h

+10-10
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ TVM_DLL FunctionFrame Function();
7878
* \param shape The shape of the parameter.
7979
* \return The created function parameter var.
8080
*/
81-
TVM_DLL tvm::relax::Var Arg(const String& name, const Type& type,
82-
const tvm::relax::ShapeExpr& shape);
81+
TVM_DLL tvm::relax::Var Arg(const String& name, const Type& type, const tvm::relax::Expr& shape);
8382

8483
/*!
8584
* \brief Specify the name of the last function frame.
@@ -99,6 +98,12 @@ TVM_DLL void FuncAttrs(Map<String, ObjectRef> attrs);
9998
*/
10099
TVM_DLL void FuncRetType(tvm::Type ret_type);
101100

101+
/*!
102+
* \brief Specify the return shape of the last function frame.
103+
* \param ret_shape The return shape.
104+
*/
105+
TVM_DLL void FuncRetShape(tvm::relax::Expr ret_shape);
106+
102107
/*!
103108
* \brief Specify the return value of the last function frame.
104109
* \param value The return value.
@@ -130,25 +135,20 @@ TVM_DLL void DataflowBlockOutput(const Array<tvm::relax::Var>& vars);
130135
/*!
131136
* \brief Emit a binding to the last binding block frame.
132137
* \param value The right side value of the bindings to be emitted.
133-
* \param is_dataflow_var A boolean indicating if the emitted binding variable is a dataflow
134-
* variable.
135138
* \return The left side var of the emitted binding.
136139
*/
137-
TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value, bool is_dataflow_var);
140+
TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value);
138141

139142
/*!
140143
* \brief Emit a match_shape binding to the last binding block frame.
141144
* \param value The value of the MatchShape to be emitted.
142145
* \param pattern The pattern of the MatchShape to be emitted.
143146
* \param emit_var A boolean indicating if the MatchShape contains the emitted variable.
144-
* \param is_dataflow_var A boolean indicating if the emitted variable is a dataflow variable when
145-
* `emit_var` is true. When `emit_var` is false, the value of this flag will be ignored.
146147
* \return The emitted var if `emit_var` is true. Otherwise, return `NullOpt`.
147148
*/
148149
TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value, //
149150
const Array<PrimExpr>& pattern, //
150-
bool emit_var, //
151-
bool is_dataflow_var);
151+
bool emit_var);
152152

153153
///////////////////////////// Type Deduce //////////////////////////////
154154

@@ -161,7 +161,7 @@ TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value,
161161
* And we annotate to the var with more detailed type.
162162
*/
163163
TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type,
164-
const Optional<tvm::relax::ShapeExpr>& anno_shape);
164+
const Optional<tvm::relax::Expr>& anno_shape);
165165

166166
///////////////////////////// If Then Else /////////////////////////////
167167

python/tvm/ir/function.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
"""Function defintiions."""
18-
from __future__ import annotations
17+
"""Function definitions."""
1918
from typing import Union, Dict
2019
from enum import IntEnum
2120
import tvm.runtime
@@ -42,7 +41,7 @@ def attrs(self):
4241
"""Return the attrs member of the function."""
4342
return _ffi_api.BaseFunc_Attrs(self)
4443

45-
def with_attr(self, attr_key_or_dict, attr_value=None) -> BaseFunc:
44+
def with_attr(self, attr_key_or_dict, attr_value=None) -> "BaseFunc":
4645
"""Create a new copy of the function and update the attribute.
4746
4847
Parameters
@@ -71,7 +70,7 @@ def with_attr(self, attr_key_or_dict, attr_value=None) -> BaseFunc:
7170
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
7271
)
7372

74-
def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> BaseFunc:
73+
def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "BaseFunc":
7574
"""Copy the IRModule and add the given attribute map to it.
7675
Parameters
7776
----------
@@ -87,7 +86,7 @@ def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> BaseFunc:
8786

8887
return _ffi_api.BaseFuncWithAttrs(self, attr_map)
8988

90-
def without_attr(self, attr_key: str) -> BaseFunc:
89+
def without_attr(self, attr_key: str) -> "BaseFunc":
9190
"""Create a new copy of the function with an attribute without provided key.
9291
9392
Parameters

python/tvm/ir/module.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""IRModule that holds the functions and type definitions."""
18-
from __future__ import annotations
1918
from typing import Optional, Union, Dict
2019
import ast
2120
from tvm._ffi.base import string_types
@@ -333,7 +332,7 @@ def get_attrs(self):
333332

334333
return _ffi_api.Module_GetAttrs(self)
335334

336-
def with_attr(self, attr_key, attr_value) -> IRModule:
335+
def with_attr(self, attr_key, attr_value) -> "IRModule":
337336
"""Copy the IRModule and add an attribute to it.
338337
339338
Parameters
@@ -352,7 +351,7 @@ def with_attr(self, attr_key, attr_value) -> IRModule:
352351

353352
return _ffi_api.Module_WithAttr(self, attr_key, attr_value)
354353

355-
def without_attr(self, attr_key: str) -> IRModule:
354+
def without_attr(self, attr_key: str) -> "IRModule":
356355
"""Copy the IRModule and remove an attribute key and its associated value.
357356
Parameters
358357
----------
@@ -366,7 +365,7 @@ def without_attr(self, attr_key: str) -> IRModule:
366365

367366
return _ffi_api.Module_WithoutAttr(self, attr_key)
368367

369-
def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> IRModule:
368+
def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> "IRModule":
370369
"""Copy the IRModule and add the given attribute map to it.
371370
Parameters
372371
----------

python/tvm/relax/dpl/pattern.py

+46-12
Original file line numberDiff line numberDiff line change
@@ -819,13 +819,33 @@ def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
819819
return PrimArrPattern(shape)
820820

821821

822+
def _is_call_tir(
823+
func_pattern: DFPattern,
824+
args: Union[List, Tuple, TuplePattern] = None,
825+
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
826+
) -> CallPattern:
827+
if args is None:
828+
args = wildcard()
829+
elif isinstance(args, (list, tuple)):
830+
args = TuplePattern(args)
831+
832+
if shape is None:
833+
shape = wildcard()
834+
elif isinstance(shape, (list, Array)):
835+
shape = PrimArrPattern(shape)
836+
elif isinstance(shape, (tuple)):
837+
shape = is_tuple(shape) # multiple shape patterns
838+
839+
return is_op("relax.call_tir")(func_pattern, args, shape)
840+
841+
822842
def is_call_tir(
823843
func_name: str,
824844
args: Union[List, Tuple, TuplePattern] = None,
825845
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
826846
) -> CallPattern:
827847
"""
828-
Syntax sugar for creating a CallPattern for call_tir
848+
Syntax sugar for creating a CallPattern for call_tir that calls an function through global var.
829849
830850
Parameters
831851
----------
@@ -841,19 +861,33 @@ def is_call_tir(
841861
CallPattern
842862
The resulting CallPattern
843863
"""
844-
if args is None:
845-
args = wildcard()
846-
elif isinstance(args, (list, tuple)):
847-
args = TuplePattern(args)
864+
func_pattern = GlobalVarPattern(func_name)
865+
return _is_call_tir(func_pattern, args, shape)
848866

849-
if shape is None:
850-
shape = wildcard()
851-
elif isinstance(shape, (list, Array)):
852-
shape = PrimArrPattern(shape)
853-
elif isinstance(shape, (tuple)):
854-
shape = is_tuple(shape) # multiple shape patterns
855867

856-
return is_op("relax.call_tir")(GlobalVarPattern(func_name), args, shape)
868+
def is_call_tir_extern(
869+
func_name: str,
870+
args: Union[List, Tuple, TuplePattern] = None,
871+
shape: Union[Tuple, List[tvm.ir.PrimExpr], DFPattern] = None,
872+
) -> CallPattern:
873+
"""Syntax sugar for creating a CallPattern for call_tir that calls an extern function
874+
875+
Parameters
876+
----------
877+
func_name : str
878+
Name of the CPS function to call.
879+
args : Union[List[DFPattern], Tuple[DFPattern]], optional
880+
Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments
881+
shape : Union[Tuple, List[tvm.ir.PrimExpr], DFPattern], optional
882+
Shape (or shapes in a tuple) of the output, by default None meaning arbitrary shape(s)
883+
884+
Returns
885+
-------
886+
CallPattern
887+
The resulting CallPattern
888+
"""
889+
func_pattern = ExternFuncPattern(func_name)
890+
return _is_call_tir(func_pattern, args, shape)
857891

858892

859893
def is_call_packed(

python/tvm/relax/expr.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,23 @@ def name_hint(self):
109109
return name
110110

111111
def __call__(self, *args: Any, attrs=None) -> Call:
112-
if self.checked_type and isinstance(self.checked_type, ty.FuncType):
112+
if self._checked_type_ and isinstance(self._checked_type_, ty.FuncType):
113113
return Call(self, args, attrs=attrs)
114114
else:
115-
raise TypeError("Only vars with function type can be called")
115+
raise TypeError(
116+
f"Only vars with function type can be called, but got type: {self._checked_type_}"
117+
)
118+
119+
def __getitem__(self, key):
120+
if not isinstance(key, int):
121+
raise TypeError("TupleGetItem only supports integer index")
122+
var_type = self._checked_type_
123+
if var_type and isinstance(var_type, ty.TupleType):
124+
return TupleGetItem(self, key)
125+
else:
126+
raise TypeError(
127+
f"Only vars with TupleType is subscriptable, but got type: {self._checked_type_}"
128+
)
116129

117130

118131
@tvm._ffi.register_object("relax.expr.DataflowVar")

0 commit comments

Comments
 (0)