Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit 5492637

Browse files
authored
[TVMScript] New Parser: Part C (#218)
1 parent d430754 commit 5492637

40 files changed

+1932
-506
lines changed

python/tvm/script/__init__.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
18+
from . import parser, parser_v1
1819

19-
from . import ir_builder, relax, tir
20-
from .parser import from_source, ir_module
20+
#############
21+
from .parser import ir as ir_v2
22+
from .parser import ir_module as ir_module_v2
23+
from .parser import parse as from_source_v2
24+
from .parser import tir as tir_v2
25+
26+
#############
27+
from .parser_v1 import from_source as from_source_v1
28+
from .parser_v1 import ir_module as ir_module_v1
29+
from .parser_v1 import relax as relax_v1
30+
from .parser_v1 import tir as tir_v1
31+
32+
# pylint: disable=invalid-name
33+
34+
# ir = ir_v1
35+
ir_module = ir_module_v1
36+
tir = tir_v1
37+
relax = relax_v1
38+
from_source = from_source_v1

python/tvm/script/ir_builder/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
# under the License.
1717
"""tvm.script.ir_builder is a generic IR builder for TVM."""
1818
from . import tir
19+
from .base import IRBuilder
1920
from .ir import ir_module

python/tvm/script/ir_builder/tir/__init__.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,4 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Package tvm.script.ir_builder.tir"""
18-
from . import frame
19-
20-
# from .ir import
18+
from .ir import * # pylint: disable=wildcard-import,redefined-builtin

python/tvm/script/parser/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the Licens.
17+
"""The parser"""
18+
from . import _core, ir, tir
19+
from ._core import parse
20+
from .ir import ir_module
21+
from .tir import prim_func

python/tvm/script/parser/_core.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the Licens.
17+
"""The core parser infra"""
18+
# pylint: disable=unused-import
19+
from .core import dispatch, doc, utils
20+
from .core.dispatch import OpMethod, register_op
21+
from .core.entry import parse
22+
from .core.parser import Parser
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""The core parser infra"""
18+
from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-docstring
18+
import inspect
19+
import re
20+
import sys
21+
from typing import Union
22+
23+
from tvm.ir import IRModule, SourceName, Span, diagnostics
24+
25+
from . import doc
26+
27+
28+
class Source:
29+
source_name: str
30+
start_line: int
31+
start_column: int
32+
source: str
33+
full_source: str
34+
35+
def __init__(self, program: Union[str, doc.AST]):
36+
if isinstance(program, str):
37+
self.source_name = "<str>"
38+
self.start_line = 1
39+
self.start_column = 0
40+
self.source = program
41+
self.full_source = program
42+
return
43+
44+
self.source_name = inspect.getsourcefile(program) # type: ignore
45+
lines, self.start_line = getsourcelines(program) # type: ignore
46+
if lines:
47+
self.start_column = len(lines[0]) - len(lines[0].lstrip())
48+
else:
49+
self.start_column = 0
50+
if self.start_column and lines:
51+
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
52+
else:
53+
self.source = "".join(lines)
54+
try:
55+
# It will cause a problem when running in Jupyter Notebook.
56+
# `mod` will be <module '__main__'>, which is a built-in module
57+
# and `getsource` will throw a TypeError
58+
mod = inspect.getmodule(program)
59+
if mod:
60+
self.full_source = inspect.getsource(mod)
61+
else:
62+
self.full_source = self.source
63+
except TypeError:
64+
# It's a work around for Jupyter problem.
65+
# Since `findsource` is an internal API of inspect, we just use it
66+
# as a fallback method.
67+
src, _ = inspect.findsource(program) # type: ignore
68+
self.full_source = "".join(src)
69+
70+
def as_ast(self) -> doc.AST:
71+
return doc.parse(self.source)
72+
73+
74+
_getfile = inspect.getfile # pylint: disable=invalid-name
75+
_findsource = inspect.findsource # pylint: disable=invalid-name
76+
77+
78+
def _patched_inspect_getfile(obj):
79+
if not inspect.isclass(obj):
80+
return _getfile(obj)
81+
mod = getattr(obj, "__module__", None)
82+
if mod is not None:
83+
file = getattr(sys.modules[mod], "__file__", None)
84+
if file is not None:
85+
return file
86+
for _, member in inspect.getmembers(obj):
87+
if inspect.isfunction(member):
88+
if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
89+
return inspect.getfile(member)
90+
raise TypeError(f"Source for {obj:!r} not found")
91+
92+
93+
def findsource(obj):
94+
import linecache # pylint: disable=import-outside-toplevel
95+
96+
if not inspect.isclass(obj):
97+
return _findsource(obj)
98+
99+
file = inspect.getsourcefile(obj)
100+
if file:
101+
linecache.checkcache(file)
102+
else:
103+
file = inspect.getfile(obj)
104+
if not (file.startswith("<") and file.endswith(">")):
105+
raise OSError("source code not available")
106+
107+
module = inspect.getmodule(obj, file)
108+
if module:
109+
lines = linecache.getlines(file, module.__dict__)
110+
else:
111+
lines = linecache.getlines(file)
112+
if not lines:
113+
raise OSError("could not get source code")
114+
qual_names = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
115+
pattern_list = []
116+
for name in qual_names:
117+
if name.endswith("<locals>"):
118+
pattern_list.append(re.compile(r"^(\s*)def\s*" + name[:-8] + r"\b"))
119+
else:
120+
pattern_list.append(re.compile(r"^(\s*)class\s*" + name + r"\b"))
121+
for i, line in enumerate(lines):
122+
match = pattern_list[0].match(line)
123+
if match:
124+
pattern_list.pop(0)
125+
if not pattern_list:
126+
return lines, i
127+
raise OSError("could not find class definition")
128+
129+
130+
def getsourcelines(obj):
131+
obj = inspect.unwrap(obj)
132+
lines, l_num = findsource(obj)
133+
return inspect.getblock(lines[l_num:]), l_num + 1
134+
135+
136+
inspect.getfile = _patched_inspect_getfile
137+
138+
139+
class Diagnostics:
140+
141+
source: Source
142+
ctx: diagnostics.DiagnosticContext
143+
144+
def __init__(self, source: Source):
145+
mod = IRModule()
146+
mod.source_map.add(source.source_name, source.full_source)
147+
self.source = source
148+
self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())
149+
150+
def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
151+
lineno = node.lineno or self.source.start_line
152+
col_offset = node.col_offset or self.source.start_column
153+
end_lineno = node.end_lineno or lineno
154+
end_col_offset = node.end_col_offset or col_offset
155+
lineno += self.source.start_line - 1
156+
end_lineno += self.source.start_line - 1
157+
col_offset += self.source.start_column + 1
158+
end_col_offset += self.source.start_column + 1
159+
self.ctx.emit(
160+
diagnostics.Diagnostic(
161+
level=level,
162+
span=Span(
163+
source_name=SourceName(self.source.source_name),
164+
line=lineno,
165+
end_line=end_lineno,
166+
column=col_offset,
167+
end_column=end_col_offset,
168+
),
169+
message=message,
170+
)
171+
)
172+
173+
def error(self, node: doc.AST, message: str) -> None:
174+
self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
175+
self.ctx.render()
+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-docstring
18+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type
19+
20+
from .doc import AST
21+
22+
if TYPE_CHECKING:
23+
from .parser import Parser
24+
25+
26+
ParseMethod = Callable[["Parser", AST], None]
27+
ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}
28+
29+
OpMethod = Callable[..., Any]
30+
OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {}
31+
32+
33+
def register(token: str, type_name: str):
34+
"""Register a method for a dispatch token and type name"""
35+
36+
def f(method: ParseMethod):
37+
ParseVTable[(token, type_name)] = method
38+
39+
return f
40+
41+
42+
def get(
43+
token: str,
44+
type_name: str,
45+
default: Optional[ParseMethod] = None,
46+
) -> Optional[ParseMethod]:
47+
return ParseVTable.get((token, type_name), default)
48+
49+
50+
def register_op(ty: Type, op: AST, operand_index: int): # pylint: disable=invalid-name
51+
def f(method: OpMethod):
52+
OpVTable[(ty, op, operand_index)] = method
53+
54+
return f
55+
56+
57+
def get_op( # pylint: disable=invalid-name
58+
ty: Type,
59+
op: Type,
60+
operand_index: int,
61+
default: Optional[OpMethod] = None,
62+
) -> Optional[OpMethod]:
63+
return OpVTable.get((ty, op, operand_index), default)

0 commit comments

Comments
 (0)