Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

[Op][Debugging] Add assert operator #260

Merged
merged 9 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,18 @@ struct PrintAttrs : public tvm::AttrsNode<PrintAttrs> {
}
};

/*!
 * \brief Attributes used by the relax.assert_op operator.
 *
 * Carries the Python-style format string applied to the call's format
 * arguments when the assertion fails; an empty string means the rendered
 * arguments are reported without formatting (see relax.run.assert_op).
 */
struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
  // Format string for the failure message; "" by default.
  std::string format;
  TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") {
    TVM_ATTR_FIELD(format)
        .describe(
            "Python-style format string to use for displaying "
            "an error message if the assert fails. "
            "Ignored if empty.")
        .set_default("");
  }
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_OP_ATTR_TYPES_H_
15 changes: 15 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ class NameTable {
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype).
*
* \param ty The input type.
* \param permit_unknown_rank If true, it will permit the input type to have unknown rank
* (ndim of -1), which will require a dynamic check.
* \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype
* (namely, void), which will require a dynamic check.
*
* \return True iff the input type is a boolean scalar type (or, depending on options, has unknown
* rank or dtype)
*/
TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
bool permit_unknown_dtype = true);

} // namespace relax
} // namespace tvm

Expand Down
145 changes: 113 additions & 32 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,53 +140,55 @@ def invoke_closure(
return _ffi_api.invoke_closure(closure, args)


def render_object(val: tvm.Object) -> str:
    """
    Render a TVM Object as a human-readable string.

    Used for Relax printing and assertions.

    Parameters
    ----------
    val: tvm.Object
        An object to render

    Returns
    -------
    ret: str
        A string representing the value, ideally human-readable
    """
    # ADTs have no pretty-printer by default, so without special handling
    # we could not look inside tuples; everything else (including NDArrays)
    # renders acceptably via str().
    if isinstance(val, tvm.runtime.container.ADT):
        # the fields array of an ADT cannot be directly accessed in Python,
        # so take the length and index into the fields one at a time
        rendered_fields = ", ".join(render_object(val[idx]) for idx in range(len(val)))
        if val.tag == 0:
            # special case: tag = 0 denotes a tuple
            return f"({rendered_fields})"
        return f"ADT(tag={val.tag}, fields=[{rendered_fields}])"
    return str(val)


@tvm.register_func("relax.run.print")
def relax_print(*args: List[any]) -> None:
def relax_print(format_str: str, *format_args: tvm.Object) -> None:
"""
Takes a list of values to print, formats with the given format string.
If the format string is empty, simply prints.

Since this function is called as a PackedFunc from the generated code,
we cannot have it be variadic _and_ have an optional format string attribute
except by taking in all the arguments as a single list. The last argument
should be a format string.

Call from TVM script like this:
`relax.print(value1, value2, ..., valueN, format=format_str)`
or
`relax.print(value1, value2, ..., valueN) # format_str defaults to ""`

Parameters
----------
vals: List[Object]
The values to print.

format_str: str
The last argument is a Python-style format string for printing the value
"""

# there is no way to have a keyword arg to a packed function,
# so the format string is always the last argument
format_str = args[-1]
if not isinstance(format_str, str):
raise ValueError("No valid format string given.")

def render(val: tvm.Object) -> str:
if isinstance(val, tvm.runtime.ndarray.NDArray):
return str(val)
# no pretty-printer by default, so if we don't handle this,
# then we can't look inside tuples
if isinstance(val, tvm.runtime.container.ADT):
# the fields array of an ADT cannot be directly accessed in Python
# so we have to get the length and index into the fields separately
fields = ", ".join([render(val[i]) for i in range(len(val))])
# special case: tag = 0 is a tuple
if val.tag == 0:
return f"({fields})"
return f"ADT(tag={val.tag}, fields=[{fields}])"
return str(val)

val_strs = map(render, args[:-1])
format_args: List[Object]
The values to print.
"""
val_strs = map(render_object, format_args)
if format_str == "":
py_print(*val_strs)
else:
Expand Down Expand Up @@ -214,6 +216,85 @@ def print(values: Union[Expr, List[Expr]], format: str) -> Expr:
return _ffi_api.print(values, format) # type: ignore # pylint: disable=no-member


@tvm.register_func("relax.run.assert_op")
def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None:
    """
    A variadic function. The first value serves as the assertion condition:
    If the condition is true, then the operator does nothing.
    If the condition is false, then the operator raises an assertion error.

    The second argument is the format string for the error message (empty by
    default; it is inserted by codegen from the operator's attributes), and any
    arguments after that serve as format arguments for the error message.
    If the format string is the empty string, then the error message will simply
    include a comma-separated list of the format arguments.
    The condition argument is not included in the format string.

    Parameters
    ----------
    condition: tvm.Object
        The assertion condition. Must be a boolean scalar.

    format_str: str
        A Python-style format string for rendering the error message.

    format_args: List[tvm.Object]
        Values used for formatting the string.

    Raises
    ------
    ValueError
        If the format string is not a string or the condition is not a
        boolean scalar NDArray.
    AssertionError
        If the condition evaluates to false.
    """
    if not isinstance(format_str, str):
        # fixed: the original message had a stray unmatched ')' at the end
        raise ValueError(
            f"The format string argument to assert must be a string, given {type(format_str)}"
        )

    # should be guaranteed by the type system
    if not isinstance(condition, tvm.runtime.ndarray.NDArray):
        raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.")

    # may happen if the original program had unknown shape or dtype for the tensor's type
    dtype = condition.dtype
    if dtype != "bool":
        raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor")
    shape = condition.shape
    if len(shape) != 0:
        raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}")

    val = condition.numpy()
    if not val:
        error_message = "Assertion Failed"
        if format_args or format_str != "":
            rendered = map(render_object, format_args)
            if format_str != "":
                error_message = format_str.format(*rendered)
            else:
                error_message = ", ".join(rendered)
        raise AssertionError(error_message)


def assert_op(condition: Expr, format_args: Optional[List[Expr]] = None, format: str = "") -> Expr:
    """
    Create a call to Relax's assert_op operation (`assert` is reserved in Python,
    so the name must be distinct).

    Parameters
    ----------
    condition: Expr
        The assertion condition.

    format_args: Optional[List[Expr]]
        Format arguments for the error message if the condition fails.

    format: str
        The format string for the error message.

    Returns
    -------
    result : Expr
        A Call to the Relax assert operation.
    """
    # avoid a mutable default argument by substituting an empty list here
    args = [] if format_args is None else format_args
    return _ffi_api.assert_op(condition, args, format)  # type: ignore


def shape_of(expr: Expr) -> Expr:
"""Get shape of a tensor.

Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,8 @@ class UniqueAttrs(Attrs):
@tvm._ffi.register_object("relax.attrs.PrintAttrs")
class PrintAttrs(Attrs):
    """Attributes used for the print operator (the message format string)."""


@tvm._ffi.register_object("relax.attrs.AssertOpAttrs")
class AssertOpAttrs(Attrs):
    """Attributes used for the assert operator (the error-message format string)."""
10 changes: 9 additions & 1 deletion src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,14 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
}
if (call_node->op == print_op_) {
auto print_attrs = call_node->attrs.as<PrintAttrs>();
args.push_back(EmitConstantFromValue(print_attrs->format));
// format string is the first argument
args.insert(args.begin(), EmitConstantFromValue(print_attrs->format));
return;
}
if (call_node->op == assert_op_) {
auto assert_attrs = call_node->attrs.as<AssertOpAttrs>();
// format string comes before the format args
args.insert(args.begin() + 1, EmitConstantFromValue(assert_attrs->format));
return;
}
LOG(FATAL) << "Support for attributes of Op " << call_node->op
Expand Down Expand Up @@ -520,6 +527,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& unique_op_ = Op::Get("relax.unique");
const Op& print_op_ = Op::Get("relax.print");
const Op& assert_op_ = Op::Get("relax.assert_op");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
};
Expand Down
46 changes: 46 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/attrs/shape.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/utils.h>
#include <tvm/relay/op.h>

#include "op_common.h"
Expand Down Expand Up @@ -118,6 +119,51 @@ Expr MakePrint(Array<Expr> vals, std::string format) {

TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint);

// assert_op

// can't actually name it assert or else Python will consider it a syntax error

Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) {
  // The condition (first argument) must be a boolean scalar. A tensor of
  // unknown rank or unknown dtype is also permitted — those cases are checked
  // dynamically at runtime. The assert itself produces no value (void).
  if (call->args.empty()) {
    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
                       << "Assert must have at least one argument (the condition).");
  }
  Type cond_type = call->args[0]->checked_type();
  if (!IsBoolScalarType(cond_type)) {
    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
                       << "The argument to assert must be a boolean scalar type, but received "
                       << cond_type);
  }
  return VoidType();
}

TVM_REGISTER_NODE_TYPE(AssertOpAttrs);

// Variadic op: the first argument is the assertion condition; any remaining
// arguments become format args for the error message. At runtime the call is
// lowered to the packed function registered as "relax.run.assert_op".
RELAY_REGISTER_OP("relax.assert_op")
    .set_attrs_type<AssertOpAttrs>()
    .set_num_inputs(-1)
    .add_argument("vals", "Array<Expr>",
                  "The first value is used as the assertion condition. The others are used as "
                  "format arguments if there is an error.")
    .set_attr<FInferType>("FInferType", InferAssertType)
    .set_attr<FCallPacked>("FCallPacked", "relax.run.assert_op");

/*!
 * \brief Build a call to relax.assert_op.
 * \param condition The assertion condition expression.
 * \param vals Format arguments for the error message.
 * \param format Python-style format string for the error message.
 * \return A Call node invoking relax.assert_op.
 */
Expr MakeAssertOp(Expr condition, Array<Expr> vals, std::string format) {
  static const Op& op = Op::Get("relax.assert_op");
  auto attrs = make_object<AssertOpAttrs>();
  attrs->format = format;
  // Calling convention: the condition comes first, followed by the format args.
  Array<Expr> call_args;
  call_args.push_back(condition);
  for (const Expr& val : vals) {
    call_args.push_back(val);
  }
  return Call(op, call_args, Attrs(attrs));
}

TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp);

// make_closure

RELAY_REGISTER_OP("relax.make_closure")
Expand Down
10 changes: 10 additions & 0 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,15 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
}
}

bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
if (!tt) {
return false;
}
bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void());
bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1);
return correct_dtype && correct_rank;
}

} // namespace relax
} // namespace tvm
56 changes: 56 additions & 0 deletions tests/python/relax/test_relax_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest
import tvm
from tvm import relax
from tvm._ffi.base import TVMError

from tvm.script import relax as R

Expand Down Expand Up @@ -88,5 +89,60 @@ def test_print():
sys.stdout = stdout


@tvm.script.ir_module
class AssertOpTest:
    # Each function wraps a single relax.assert_op call so the test below can
    # exercise passing/failing conditions and message formatting separately.

    @R.function
    def passes(x: Tensor((), "int32")):
        # constant-true condition: the assertion is a no-op
        p1 = relax.assert_op(relax.const(True))
        return x

    @R.function
    def pass_with_args(x: Tensor((), "int32")):
        # the format string is never rendered when the condition holds
        p1 = relax.assert_op(relax.const(True), x, format="You won't see me")
        return x

    @R.function
    def simple_fail(x: Tensor((), "int32")):
        # no message: expected to raise with the default "Assertion Failed"
        p1 = relax.assert_op(relax.const(False))
        return x

    @R.function
    def fail_with_message(x: Tensor((), "int32")):
        # fixed message with no format arguments
        p1 = relax.assert_op(relax.const(False), format="I failed...")
        return x

    @R.function
    def fail_with_args(x: Tensor((), "int32")):
        # no format
        p1 = relax.assert_op(relax.const(False), x, x)
        return x

    @R.function
    def fail_with_formatted_message(x: Tensor((), "int32")):
        # format string applied to the single format argument
        p1 = relax.assert_op(relax.const(False), x, format="Number: {}")
        return x


def test_assert_op():
    """Check that relax.assert_op passes/fails and formats error messages as expected."""

    def check_assertion_error(func_name, func_arg, expected_message):
        # Use pytest.raises instead of the manual try/except + flag pattern:
        # it fails the test automatically if nothing is raised.
        with pytest.raises(TVMError) as exc_info:
            run_cpu(AssertOpTest, func_name, func_arg)
        # TVM will print out a TVMError that will contain the
        # generated error at the bottom of a stack trace
        message = exc_info.value.args[0]
        assert "AssertionError" in message
        assert expected_message in message

    # passing assertions must run to completion without raising
    run_cpu(AssertOpTest, "passes", tvm.nd.array(1))
    run_cpu(AssertOpTest, "pass_with_args", tvm.nd.array(2))
    # failing assertions: default message, fixed message, comma-joined
    # format args, and a format string applied to the args
    check_assertion_error("simple_fail", tvm.nd.array(3), "Assertion Failed")
    check_assertion_error("fail_with_message", tvm.nd.array(4), "I failed...")
    check_assertion_error("fail_with_args", tvm.nd.array(5), "5, 5")
    check_assertion_error("fail_with_formatted_message", tvm.nd.array(6), "Number: 6")


if __name__ == "__main__":
pytest.main([__file__])