Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit 32a03f8

Browse files
authored
[TVMScript] Update Type Annotation Behavior of the Parser (#269)
This commit changes the parser's handling of type annotations, as suggested by the community. Current behavior: use the more refined of the user-annotated and the deduced type/shape. Updated behavior: always use the user's annotations, and only check that the annotated type/shape is valid.
1 parent c6d6a06 commit 32a03f8

File tree

4 files changed

+56
-53
lines changed

4 files changed

+56
-53
lines changed

include/tvm/script/ir_builder/relax/ir.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,13 @@ TVM_DLL Optional<tvm::relax::Var> EmitMatchShape(const tvm::relax::Expr& value,
155155
/*!
156156
* \brief Annotate and check the type and shape of relax var.
157157
* \param var The input var to be annotated.
158-
* \param type The given type.
159-
* \param shape The given shape, which can be undefined.
160-
* \note This function will check if the type of var is compatible with the given type.
158+
* \param anno_type The annotated type.
159+
* \param anno_shape The annotated shape, which can be undefined.
160+
* \note This function will check if the type of var is compatible with the annotated type.
161161
* And we annotate to the var with more detailed type.
162162
*/
163-
TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& type,
164-
const Optional<tvm::relax::ShapeExpr>& shape);
163+
TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type,
164+
const Optional<tvm::relax::ShapeExpr>& anno_shape);
165165

166166
///////////////////////////// If Then Else /////////////////////////////
167167

python/tvm/script/ir_builder/relax/ir.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -287,18 +287,21 @@ def emit_match_shape(
287287
############################# Type Deduce ##############################
288288

289289

290-
def annotate_type_shape(var: Var, type: Type, shape: ShapeExpr) -> None:
290+
def annotate_type_shape(var: Var, anno_type: Type, anno_shape: ShapeExpr) -> None:
291291
"""Annotate and check the type of relax var.
292292
Parameters
293293
----------
294294
var: Var
295295
The input var to be annotated.
296-
type: Type
297-
The given type
298-
shape: ShapeExpr
299-
The given shape
296+
297+
anno_type: Type
298+
The annotated type
299+
300+
anno_shape: ShapeExpr
301+
The annotated shape
302+
300303
"""
301-
_ffi_api.AnnotateTypeShape(var, type, shape)
304+
_ffi_api.AnnotateTypeShape(var, anno_type, anno_shape)
302305

303306

304307
def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name

src/script/ir_builder/relax/ir.cc

+12-21
Original file line numberDiff line numberDiff line change
@@ -256,35 +256,26 @@ TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchShape").set_body_typed(Emi
256256

257257
///////////////////////////// Type Deduce //////////////////////////////
258258

259-
void AnnotateTypeShape(const tvm::relax::Var& var, const Type& type,
260-
const Optional<tvm::relax::ShapeExpr>& shape) {
259+
void AnnotateTypeShape(const tvm::relax::Var& var, const Type& anno_type,
260+
const Optional<tvm::relax::ShapeExpr>& anno_shape) {
261261
using tvm::relax::IsBaseOf;
262-
if (!var->checked_type_.defined()) {
263-
var->checked_type_ = type;
264-
} else {
262+
if (var->checked_type_.defined()) {
265263
const Type& var_type = var->checked_type();
266-
if (IsBaseOf(type, var_type)) {
267-
// The var type is equal or more detailed than annotated one, do nothing.
268-
} else if (IsBaseOf(var_type, type)) {
269-
LOG(WARNING) << "The inferred type of var " << var->name_hint()
270-
<< " by the block builder is more refined than the annotated one. The system "
271-
"will refine it automatically.";
272-
var->checked_type_ = type;
273-
} else {
274-
LOG(FATAL) << "TypeError: The annotated type and value type are not compatible. "
275-
<< "The Type is expected to be " << var_type << " but got annotation: " << type;
276-
}
264+
CHECK(IsBaseOf(anno_type, var_type) || IsBaseOf(var_type, anno_type))
265+
<< "TypeError: The annotated type and value type are not compatible. "
266+
<< "The Type is expected to be " << var_type << " but got annotation: " << anno_type;
277267
}
278268

279-
if (!var->shape_.defined()) {
280-
var->shape_ = shape;
281-
} else if (shape.defined()) {
269+
if (var->shape_.defined() && anno_shape.defined()) {
282270
const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder();
283271
tvm::relax::Expr var_shape = Downcast<tvm::relax::Expr>(var->shape_.value());
284-
CHECK(block_builder->CanProveShapeEqual(var_shape, shape.value()))
272+
CHECK(block_builder->CanProveShapeEqual(var_shape, anno_shape.value()))
285273
<< " The shape of var " << var->name_hint() << " is expected to be " << var_shape
286-
<< " but got annotation: " << shape.value();
274+
<< " but got annotation: " << anno_shape.value();
287275
}
276+
277+
var->checked_type_ = anno_type;
278+
var->shape_ = anno_shape;
288279
}
289280

290281
TVM_REGISTER_GLOBAL("script.ir_builder.relax.AnnotateTypeShape").set_body_typed(AnnotateTypeShape);

tests/python/relax/test_tvmscript_parser.py

+30-21
Original file line numberDiff line numberDiff line change
@@ -427,28 +427,37 @@ def foo(
427427
o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, type_args=R.Object)
428428
return o
429429

430-
m = tir.Var("m", "int64")
431-
x = relax.Var("x", (32, m), relax.DynTensorType(2, "float32"))
432-
y = relax.Var("y", (m,), relax.DynTensorType(1, "float32"))
433-
r = relax.Var("r", None, relax.DynTensorType(-1, "int64"))
434-
bb = relax.BlockBuilder()
435-
with bb.function("foo", (x, y, r)):
436-
z = bb.emit(R.multiply(x, y))
437-
w = bb.emit(R.multiply(z, z))
438-
q = bb.emit(R.add(w, w))
439-
t = bb.emit(R.add(w, z))
440-
sh = bb.emit(R.shape_of(t))
441-
o = bb.emit(
442-
relax.Call(
443-
relax.ExternFunc("contrib.tensor_array_stack"),
444-
[x, y],
445-
None,
446-
type_args=[relax.ObjectType()],
447-
)
448-
)
449-
bb.emit_func_output(o)
430+
def _check_type_shape(binding, expected_type, expected_shape):
431+
tvm.ir.assert_structural_equal(binding.var.checked_type, expected_type)
432+
tvm.ir.assert_structural_equal(binding.var.shape_, expected_shape)
433+
434+
# Cannot use block builder here because we need to check the annotated type,
435+
# which may be inconsistent with deduced type.
436+
assert isinstance(foo.ret_type, relax.ObjectType)
437+
m = foo.params[0].shape[1]
438+
bindings = foo.body.blocks[0].bindings
439+
_check_type_shape(
440+
bindings[0], relax.DynTensorType(ndim=2, dtype="float32"), relax.ShapeExpr([32, m])
441+
)
442+
_check_type_shape(bindings[1], relax.DynTensorType(dtype=""), None)
443+
_check_type_shape(bindings[2], relax.DynTensorType(ndim=2, dtype=""), None)
444+
_check_type_shape(bindings[3], relax.DynTensorType(dtype=""), None)
445+
_check_type_shape(bindings[4], relax.ShapeType(), None)
446+
_check_type_shape(bindings[5], relax.ObjectType(), None)
447+
448+
449+
def test_annotate_override():
450+
@R.function
451+
def foo(x: R.Tensor):
452+
y = x
453+
# z will be treated as object type even though it's a tensor
454+
z: R.Object = y
455+
return z
450456

451-
_check(foo, bb.get()["foo"])
457+
assert isinstance(foo.ret_type, relax.ObjectType)
458+
y_bind, z_bind = foo.body.blocks[0].bindings
459+
assert isinstance(y_bind.var.checked_type, relax.DynTensorType)
460+
assert isinstance(z_bind.var.checked_type, relax.ObjectType)
452461

453462

454463
def test_empty_shape():

0 commit comments

Comments (0)