Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit 97ffe01

Browse files
authored
[TVMScript] A1: Relax Parser infra (#240)
This PR introduces the tvm script parser for relax dialect. Note that it's only a PR for infrastructure, which does not support all features yet.
1 parent 440e5f1 commit 97ffe01

File tree

9 files changed

+244
-19
lines changed

9 files changed

+244
-19
lines changed

python/tvm/script/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .parser import ir_module as ir_module_v2
2323
from .parser import parse as from_source_v2
2424
from .parser import tir as tir_v2
25+
from .parser import relax as relax_v2
2526

2627
#############
2728
from .parser_v1 import from_source as from_source_v1

python/tvm/script/parser/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the Licens.
1717
"""The parser"""
18-
from . import _core, ir, tir
18+
from . import _core, ir, tir, relax
1919
from ._core import parse
2020
from .ir import ir_module
2121
from .tir import prim_func
22+
from .relax import function

python/tvm/script/parser/ir/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616
# under the License.
1717
# pylint: disable=missing-docstring
1818
from . import parser as _parser
19-
from .entry import ir_module
19+
from .entry import ir_module, is_defined_in_class
2020

21-
__all__ = ["ir_module"]
21+
__all__ = ["ir_module", "is_defined_in_class"]

python/tvm/script/parser/ir/entry.py

+15
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,21 @@
2323
from .._core import parse, utils
2424

2525

26+
def is_defined_in_class(frames):
27+
if len(frames) > 2:
28+
maybe_class_frame = frames[2]
29+
statement_list = maybe_class_frame[4]
30+
if statement_list is None:
31+
return False
32+
first_statement = statement_list[0]
33+
line = first_statement.strip()
34+
if line.startswith("class "):
35+
return True
36+
if line.startswith("@") and "ir_module" in line:
37+
return True
38+
return False
39+
40+
2641
def ir_module(f: Type) -> IRModule:
2742
if not inspect.isclass(f):
2843
raise TypeError(f"Expect a class, but got: {f}")
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-docstring
18+
from ...ir_builder.relax import * # pylint: disable=redefined-builtin
19+
from ...ir_builder.relax import ir as _relax
20+
from . import parser as _parser
21+
from .entry import function, Tensor
22+
23+
24+
__all__ = _relax.__all__ + ["function", "Tensor"]
+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-docstring
18+
import inspect
19+
from typing import Callable, List, Optional, Union, TypeVar
20+
21+
from tvm.relax import Function, Var
22+
from tvm.tir import PrimExpr
23+
24+
from ...ir_builder.relax import tensor, TensorType
25+
from .._core import parse, utils
26+
from ..ir import is_defined_in_class
27+
28+
29+
FType = TypeVar("FType", bound=Callable)
30+
31+
32+
def function(f: FType) -> Union[Function, FType]:
33+
if not inspect.isfunction(f):
34+
raise TypeError(f"Expect a function, but got: {f}")
35+
if is_defined_in_class(inspect.stack()):
36+
return f
37+
return parse(f, utils.inspect_function_capture(f))
38+
39+
40+
setattr(function, "dispatch_token", "relax")
41+
42+
43+
class TensorProxy:
44+
def __call__(
45+
self,
46+
shape: Optional[List[Union[PrimExpr, str]]] = None,
47+
dtype: str = None,
48+
ndim: int = -1,
49+
) -> TensorType:
50+
return tensor(shape, dtype, ndim)
51+
52+
def __getitem__(self, keys) -> Var:
53+
return self(*keys) # pylint: disable=no-member # type: ignore
54+
55+
56+
Tensor = TensorProxy() # pylint: disable=invalid-name
+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-docstring,
18+
19+
from typing import Any
20+
21+
from tvm import relax
22+
23+
from ...ir_builder import relax as R
24+
from ...ir_builder.base import IRBuilder
25+
from .._core import Parser, dispatch, doc
26+
27+
28+
def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any:
29+
# pylint: disable=unused-argument
30+
if isinstance(value, relax.Expr):
31+
var = R.emit(value)
32+
IRBuilder.name(var_name, var)
33+
return var
34+
else:
35+
raise TypeError(f"Unsupported type {type(value)} in assignment")
36+
37+
38+
@dispatch.register(token="relax", type_name="FunctionDef")
39+
def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
40+
with self.var_table.with_frame():
41+
with R.function():
42+
R.func_name(node.name)
43+
if node.returns is not None:
44+
R.func_ret_type(self.eval_expr(node.returns).type)
45+
with self.with_dispatch_token("relax"):
46+
self.visit(node.args)
47+
self.visit_body(node.body)
48+
49+
50+
@dispatch.register(token="relax", type_name="Expr")
51+
def visit_expr_stmt(self: Parser, node: doc.FunctionDef) -> None:
52+
self.eval_expr(node.value)
53+
54+
55+
@dispatch.register(token="relax", type_name="arguments")
56+
def visit_arguments(self: Parser, node: doc.arguments) -> None:
57+
arg: doc.arg
58+
for arg in node.args:
59+
if arg.annotation is None:
60+
self.report_error(arg, "Type annotation is required for function parameters.")
61+
param_type = self.visit_tvm_annotation(arg.annotation)
62+
param = R.arg(arg.arg, param_type)
63+
64+
self.var_table.add(arg.arg, param)
65+
66+
67+
@dispatch.register(token="relax", type_name="tvm_annotation")
68+
def visit_tvm_annotation(self: Parser, node: doc.expr):
69+
annotation = self.eval_expr(node)
70+
if callable(annotation):
71+
annotation = annotation()
72+
return annotation
73+
74+
75+
@dispatch.register(token="relax", type_name="Assign")
76+
def visit_assign(self: Parser, node: doc.Assign) -> None:
77+
if len(node.targets) != 1:
78+
self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.")
79+
lhs = node.targets[0]
80+
rhs = self.eval_expr(node.value)
81+
self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value)
82+
83+
84+
@dispatch.register(token="relax", type_name="Return")
85+
def visit_return(self: Parser, node: doc.Assign) -> None:
86+
value = self.eval_expr(node.value)
87+
88+
if isinstance(value, relax.Expr):
89+
R.func_ret_value(value)
90+
else:
91+
self.report_error(node, f"Unsupported return value type {type(value)}.")

python/tvm/script/parser/tir/entry.py

+2-16
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,13 @@
2222

2323
from ...ir_builder.tir import buffer_decl, ptr
2424
from .._core import parse, utils
25-
26-
27-
def _is_defined_in_class(frames):
28-
if len(frames) > 2:
29-
maybe_class_frame = frames[2]
30-
statement_list = maybe_class_frame[4]
31-
if statement_list is None:
32-
return False
33-
first_statement = statement_list[0]
34-
line = first_statement.strip()
35-
if line.startswith("class "):
36-
return True
37-
if line.startswith("@") and "ir_module" in line:
38-
return True
39-
return False
25+
from ..ir import is_defined_in_class
4026

4127

4228
def prim_func(f: Callable) -> Union[PrimFunc, Callable]:
4329
if not inspect.isfunction(f):
4430
raise TypeError(f"Expect a function, but got: {f}")
45-
if _is_defined_in_class(inspect.stack()):
31+
if is_defined_in_class(inspect.stack()):
4632
return f
4733
return parse(f, utils.inspect_function_capture(f))
4834

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from typing import Union
18+
import tvm
19+
import tvm.testing
20+
21+
from tvm import relax
22+
from tvm import IRModule
23+
from tvm.script.parser import ir as I, tir as T, relax as R
24+
25+
26+
def _check(
27+
parsed: Union[relax.Function, IRModule],
28+
expect: Union[relax.Function, IRModule],
29+
):
30+
# TODO(siyuan): add round-trip tests
31+
tvm.ir.assert_structural_equal(parsed, expect)
32+
33+
34+
def test_simple_func():
35+
@R.function
36+
def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2):
37+
R.func_attr({"Primitive": 1})
38+
gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32")
39+
return gv0
40+
41+
x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32"))
42+
bb = relax.BlockBuilder()
43+
with bb.function("foo", (x,), attrs={"Primitive": 1}):
44+
out = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32"))
45+
bb.emit_func_output(out)
46+
47+
_check(foo, bb.get()["foo"])
48+
49+
50+
if __name__ == "__main__":
51+
tvm.testing.main()

0 commit comments

Comments
 (0)