Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit db5dcf1

Browse files
Josh Frommjunrushao
Josh Fromm
authored andcommitted
[Call TIR] Fix bug when invoking call_tir with scalar values. (#254)
This small PR changes a check in the tvmscript parser to support empty shape tuples which are used to represent scalars. I added a scalar addition test to make sure it works properly.
1 parent de59492 commit db5dcf1

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

python/tvm/script/relax/parser.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1399,7 +1399,8 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr:
13991399
elif isinstance(expr, ast.Tuple):
14001400
fields = [self.transform_expr(field) for field in expr.values]
14011401

1402-
if all([isinstance(f, str) for f in fields]):
1402+
# Empty shape tuples should be treated as shape expressions.
1403+
if all([isinstance(f, str) for f in fields]) and len(fields) != 0:
14031404
return tuple(fields)
14041405

14051406
# TODO(@altanh): this check might be too weak; we really only accept integral PrimExprs

tests/python/relax/test_parser.py

+28
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,34 @@ def f(x: Tensor) -> Tensor:
721721
)
722722

723723

724+
def test_empty_shape():
725+
@R.function
726+
def f(x: Tensor((), "float32"), y: Tensor((), "float32")):
727+
@T.prim_func
728+
def scalar_add(a: T.handle, b: T.handle, c: T.handle) -> None:
729+
A = T.match_buffer(a, ())
730+
B = T.match_buffer(b, ())
731+
C = T.match_buffer(c, ())
732+
733+
with T.block("add"):
734+
C[()] = A[()] + B[()]
735+
736+
z = relax.call_tir(scalar_add, (x, y), (), dtype="float32")
737+
return z
738+
739+
x, y = f.params
740+
add_bind, z_bind = f.body.blocks[0].bindings
741+
742+
assert add_bind.var.name_hint == "scalar_add"
743+
assert isinstance(add_bind.value, tir.PrimFunc)
744+
745+
check_call(
746+
z_bind.value,
747+
"relax.call_tir",
748+
[add_bind.var, relax.Tuple([x, y]), relax.ShapeExpr([])],
749+
)
750+
751+
724752
def test_class_irmodule():
725753
@tvm.script.ir_module
726754
class MyModule:

0 commit comments

Comments
 (0)