Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

[TVMScript] New Parser: Part C #218

Merged
merged 1 commit into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions python/tvm/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""
from . import parser, parser_v1

from . import ir_builder, relax, tir
from .parser import from_source, ir_module
#############
from .parser import ir as ir_v2
from .parser import ir_module as ir_module_v2
from .parser import parse as from_source_v2
from .parser import tir as tir_v2

#############
from .parser_v1 import from_source as from_source_v1
from .parser_v1 import ir_module as ir_module_v1
from .parser_v1 import relax as relax_v1
from .parser_v1 import tir as tir_v1

# pylint: disable=invalid-name

# ir = ir_v1
ir_module = ir_module_v1
tir = tir_v1
relax = relax_v1
from_source = from_source_v1
1 change: 1 addition & 0 deletions python/tvm/script/ir_builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
# under the License.
"""tvm.script.ir_builder is a generic IR builder for TVM."""
from . import tir
from .base import IRBuilder
from .ir import ir_module
4 changes: 1 addition & 3 deletions python/tvm/script/ir_builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,4 @@
# specific language governing permissions and limitations
# under the License.
"""Package tvm.script.ir_builder.tir"""
from . import frame

# from .ir import
from .ir import * # pylint: disable=wildcard-import,redefined-builtin
21 changes: 21 additions & 0 deletions python/tvm/script/parser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the Licens.
"""The parser"""
from . import _core, ir, tir
from ._core import parse
from .ir import ir_module
from .tir import prim_func
22 changes: 22 additions & 0 deletions python/tvm/script/parser/_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the Licens.
"""The core parser infra"""
# pylint: disable=unused-import
from .core import dispatch, doc, utils
from .core.dispatch import OpMethod, register_op
from .core.entry import parse
from .core.parser import Parser
18 changes: 18 additions & 0 deletions python/tvm/script/parser/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The core parser infra"""
from . import diagnostics, dispatch, doc, doc_core, entry, evaluator, parser, utils
175 changes: 175 additions & 0 deletions python/tvm/script/parser/core/diagnostics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import inspect
import re
import sys
from typing import Union

from tvm.ir import IRModule, SourceName, Span, diagnostics

from . import doc


class Source:
source_name: str
start_line: int
start_column: int
source: str
full_source: str

def __init__(self, program: Union[str, doc.AST]):
if isinstance(program, str):
self.source_name = "<str>"
self.start_line = 1
self.start_column = 0
self.source = program
self.full_source = program
return

self.source_name = inspect.getsourcefile(program) # type: ignore
lines, self.start_line = getsourcelines(program) # type: ignore
if lines:
self.start_column = len(lines[0]) - len(lines[0].lstrip())
else:
self.start_column = 0
if self.start_column and lines:
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
else:
self.source = "".join(lines)
try:
# It will cause a problem when running in Jupyter Notebook.
# `mod` will be <module '__main__'>, which is a built-in module
# and `getsource` will throw a TypeError
mod = inspect.getmodule(program)
if mod:
self.full_source = inspect.getsource(mod)
else:
self.full_source = self.source
except TypeError:
# It's a work around for Jupyter problem.
# Since `findsource` is an internal API of inspect, we just use it
# as a fallback method.
src, _ = inspect.findsource(program) # type: ignore
self.full_source = "".join(src)

def as_ast(self) -> doc.AST:
return doc.parse(self.source)


_getfile = inspect.getfile # pylint: disable=invalid-name
_findsource = inspect.findsource # pylint: disable=invalid-name


def _patched_inspect_getfile(obj):
if not inspect.isclass(obj):
return _getfile(obj)
mod = getattr(obj, "__module__", None)
if mod is not None:
file = getattr(sys.modules[mod], "__file__", None)
if file is not None:
return file
for _, member in inspect.getmembers(obj):
if inspect.isfunction(member):
if obj.__qualname__ + "." + member.__name__ == member.__qualname__:
return inspect.getfile(member)
raise TypeError(f"Source for {obj:!r} not found")


def findsource(obj):
import linecache # pylint: disable=import-outside-toplevel

if not inspect.isclass(obj):
return _findsource(obj)

file = inspect.getsourcefile(obj)
if file:
linecache.checkcache(file)
else:
file = inspect.getfile(obj)
if not (file.startswith("<") and file.endswith(">")):
raise OSError("source code not available")

module = inspect.getmodule(obj, file)
if module:
lines = linecache.getlines(file, module.__dict__)
else:
lines = linecache.getlines(file)
if not lines:
raise OSError("could not get source code")
qual_names = obj.__qualname__.replace(".<locals>", "<locals>").split(".")
pattern_list = []
for name in qual_names:
if name.endswith("<locals>"):
pattern_list.append(re.compile(r"^(\s*)def\s*" + name[:-8] + r"\b"))
else:
pattern_list.append(re.compile(r"^(\s*)class\s*" + name + r"\b"))
for i, line in enumerate(lines):
match = pattern_list[0].match(line)
if match:
pattern_list.pop(0)
if not pattern_list:
return lines, i
raise OSError("could not find class definition")


def getsourcelines(obj):
obj = inspect.unwrap(obj)
lines, l_num = findsource(obj)
return inspect.getblock(lines[l_num:]), l_num + 1


inspect.getfile = _patched_inspect_getfile


class Diagnostics:

source: Source
ctx: diagnostics.DiagnosticContext

def __init__(self, source: Source):
mod = IRModule()
mod.source_map.add(source.source_name, source.full_source)
self.source = source
self.ctx = diagnostics.DiagnosticContext(mod, diagnostics.get_renderer())

def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) -> None:
lineno = node.lineno or self.source.start_line
col_offset = node.col_offset or self.source.start_column
end_lineno = node.end_lineno or lineno
end_col_offset = node.end_col_offset or col_offset
lineno += self.source.start_line - 1
end_lineno += self.source.start_line - 1
col_offset += self.source.start_column + 1
end_col_offset += self.source.start_column + 1
self.ctx.emit(
diagnostics.Diagnostic(
level=level,
span=Span(
source_name=SourceName(self.source.source_name),
line=lineno,
end_line=end_lineno,
column=col_offset,
end_column=end_col_offset,
),
message=message,
)
)

def error(self, node: doc.AST, message: str) -> None:
self._emit(node, message, diagnostics.DiagnosticLevel.ERROR)
self.ctx.render()
63 changes: 63 additions & 0 deletions python/tvm/script/parser/core/dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type

from .doc import AST

if TYPE_CHECKING:
from .parser import Parser


ParseMethod = Callable[["Parser", AST], None]
ParseVTable: Dict[Tuple[str, str], ParseMethod] = {}

OpMethod = Callable[..., Any]
OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {}


def register(token: str, type_name: str):
"""Register a method for a dispatch token and type name"""

def f(method: ParseMethod):
ParseVTable[(token, type_name)] = method

return f


def get(
token: str,
type_name: str,
default: Optional[ParseMethod] = None,
) -> Optional[ParseMethod]:
return ParseVTable.get((token, type_name), default)


def register_op(ty: Type, op: AST, operand_index: int): # pylint: disable=invalid-name
def f(method: OpMethod):
OpVTable[(ty, op, operand_index)] = method

return f


def get_op( # pylint: disable=invalid-name
ty: Type,
op: Type,
operand_index: int,
default: Optional[OpMethod] = None,
) -> Optional[OpMethod]:
return OpVTable.get((ty, op, operand_index), default)
Loading