
Commit d0659c0

[Op][Debugging] Add assert operator (#260)
It was brought up that Relay lacks an assert operator, so we may as well have one in Relax for debugging. One issue is that we can't name it `assert`: since `assert` is a Python keyword, it cannot be used as a field name of the `relax` module, i.e., `relax.assert` is a syntax error. The op is therefore named `assert_op`, which is not ideal but serves its purpose.
1 parent 9fa3f31 commit d0659c0

8 files changed: +266 −33 lines
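For orientation, here is a minimal usage sketch (not part of the commit) showing how the new operator can be constructed from Python with the `assert_op` helper added in `python/tvm/relax/op/base.py`; the direct module import and the constant values are illustrative assumptions based on this diff.

```python
# Hypothetical usage sketch based on this commit; not code from the diff itself.
from tvm import relax
from tvm.relax.op.base import assert_op  # helper defined in this commit

cond = relax.const(False)  # rank-0 boolean condition expression
x = relax.const(6)         # value interpolated into the error message

# Builds a Call node to the "relax.assert_op" operator; the format string is
# stored in AssertOpAttrs, and [cond, x] become the call's arguments.
call = assert_op(cond, [x], format="Number: {}")
print(call.op.name)       # "relax.assert_op"
print(call.attrs.format)  # "Number: {}"
```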

include/tvm/relax/op_attr_types.h

+12
@@ -97,6 +97,18 @@ struct PrintAttrs : public tvm::AttrsNode<PrintAttrs> {
   }
 };
 
+struct AssertOpAttrs : public tvm::AttrsNode<AssertOpAttrs> {
+  std::string format;
+  TVM_DECLARE_ATTRS(AssertOpAttrs, "relax.attrs.AssertOpAttrs") {
+    TVM_ATTR_FIELD(format)
+        .describe(
+            "Python-style format string to use for displaying "
+            "an error message if the assert fails. "
+            "Ignored if empty.")
+        .set_default("");
+  }
+};
+
 }  // namespace relax
 }  // namespace tvm
 #endif  // TVM_RELAX_OP_ATTR_TYPES_H_

include/tvm/relax/utils.h

+15
@@ -107,6 +107,21 @@ class NameTable {
  */
 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
 
+/*!
+ * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype).
+ *
+ * \param ty The input type.
+ * \param permit_unknown_rank If true, it will permit the input type to have unknown rank
+ * (ndim of -1), which will require a dynamic check.
+ * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype
+ * (namely, void), which will require a dynamic check.
+ *
+ * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown
+ * rank or dtype)
+ */
+TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
+                              bool permit_unknown_dtype = true);
+
 }  // namespace relax
 }  // namespace tvm

python/tvm/relax/op/base.py

+113 −32

@@ -140,53 +140,55 @@ def invoke_closure(
     return _ffi_api.invoke_closure(closure, args)
 
 
+def render_object(val: tvm.Object) -> str:
+    """
+    Given a TVM Object, renders it in string form. Used for Relax printing and assertions.
+
+    Parameters
+    ----------
+    val: tvm.Object
+        An object to render
+
+    Returns
+    -------
+    ret: str
+        A string representing the value, ideally human-readable
+    """
+    if isinstance(val, tvm.runtime.ndarray.NDArray):
+        return str(val)
+    # no pretty-printer by default, so if we don't handle this,
+    # then we can't look inside tuples
+    if isinstance(val, tvm.runtime.container.ADT):
+        # the fields array of an ADT cannot be directly accessed in Python,
+        # so we have to get the length and index into the fields separately
+        fields = ", ".join([render_object(val[i]) for i in range(len(val))])
+        # special case: tag = 0 is a tuple
+        if val.tag == 0:
+            return f"({fields})"
+        return f"ADT(tag={val.tag}, fields=[{fields}])"
+    return str(val)
+
+
 @tvm.register_func("relax.run.print")
-def relax_print(*args: List[any]) -> None:
+def relax_print(format_str: str, *format_args: tvm.Object) -> None:
     """
     Takes a list of values to print, formats with the given format string.
     If the format string is empty, simply prints.
 
-    Since this function is called as a PackedFunc from the generated code,
-    we cannot have it be variadic _and_ have an optional format string attribute
-    except by taking in all the arguments as a single list. The last argument
-    should be a format string.
-
     Call from TVM script like this:
     `relax.print(value1, value2, ..., valueN, format=format_str)`
     or
     `relax.print(value1, value2, ..., valueN) # format_str defaults to ""`
 
     Parameters
     ----------
-    vals: List[Object]
-        The values to print.
-
     format_str: str
         The last argument is a Python-style format string for printing the value
-    """
-
-    # there is no way to have a keyword arg to a packed function,
-    # so the format string is always the last argument
-    format_str = args[-1]
-    if not isinstance(format_str, str):
-        raise ValueError("No valid format string given.")
-
-    def render(val: tvm.Object) -> str:
-        if isinstance(val, tvm.runtime.ndarray.NDArray):
-            return str(val)
-        # no pretty-printer by default, so if we don't handle this,
-        # then we can't look inside tuples
-        if isinstance(val, tvm.runtime.container.ADT):
-            # the fields array of an ADT cannot be directly accessed in Python
-            # so we have to get the length and index into the fields separately
-            fields = ", ".join([render(val[i]) for i in range(len(val))])
-            # special case: tag = 0 is a tuple
-            if val.tag == 0:
-                return f"({fields})"
-            return f"ADT(tag={val.tag}, fields=[{fields}])"
-        return str(val)
 
-    val_strs = map(render, args[:-1])
+    format_args: List[Object]
+        The values to print.
+    """
+    val_strs = map(render_object, format_args)
     if format_str == "":
         py_print(*val_strs)
     else:
@@ -214,6 +216,85 @@ def print(values: Union[Expr, List[Expr]], format: str) -> Expr:
     return _ffi_api.print(values, format)  # type: ignore # pylint: disable=no-member
 
 
+@tvm.register_func("relax.run.assert_op")
+def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Object) -> None:
+    """
+    A variadic function. The first value serves as the assertion condition:
+    If the condition is true, then the operator does nothing.
+    If the condition is false, then the operator raises an assertion error.
+
+    Arguments after the first value serve as format arguments for the error message;
+    the second argument is the format string for the error message (empty by default).
+    If the format string is the empty string, then the error message will simply include
+    a comma-separated list of the format arguments.
+    The condition argument is not included in the format string.
+
+    Parameters
+    ----------
+    condition: tvm.Object
+        The assertion condition. Must be a boolean scalar.
+
+    format_str: str
+        A Python-style format string for the error message.
+
+    format_args: List[tvm.Object]
+        Values used for formatting the string.
+    """
+    if not isinstance(format_str, str):
+        raise ValueError(
+            f"The format string argument to assert must be a string, given {type(format_str)})"
+        )
+
+    # should be guaranteed by the type system
+    if not isinstance(condition, tvm.runtime.ndarray.NDArray):
+        raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.")
+
+    # may happen if the original program had unknown shape or dtype for the tensor's type
+    dtype = condition.dtype
+    if dtype != "bool":
+        raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor")
+    shape = condition.shape
+    if len(shape) != 0:
+        raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}")
+
+    val = condition.numpy()
+    if not val:
+        error_message = "Assertion Failed"
+        if format_args or format_str != "":
+            rendered = map(render_object, format_args)
+            if format_str != "":
+                error_message = format_str.format(*rendered)
+            else:
+                error_message = ", ".join(rendered)
+        raise AssertionError(error_message)
+
+
+def assert_op(condition: Expr, format_args: Optional[List[Expr]] = None, format: str = "") -> Expr:
+    """
+    Create a call to Relax's assert_op operation (`assert` is reserved in Python,
+    so the name must be distinct).
+
+    Parameters
+    ----------
+    condition: Expr
+        The assertion condition.
+
+    format_args: List[Expr]
+        Format arguments for the error message if the condition fails.
+
+    format: str
+        The format string for the error message.
+
+    Returns
+    -------
+    result : Expr
+        A Call to the Relax assert operation.
+    """
+    if format_args is None:
+        format_args = []
+    return _ffi_api.assert_op(condition, format_args, format)  # type: ignore
+
+
 def shape_of(expr: Expr) -> Expr:
     """Get shape of a tensor.
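To make the runtime path concrete, here is a small illustrative sketch (not part of the commit) that calls the registered packed function `relax.run.assert_op` directly, which is the same function the VM reaches through `FCallPacked`. The argument order `(condition, format_str, *format_args)` matches the registration above; exactly how the `AssertionError` surfaces across the FFI boundary (for example, wrapped in a `TVMError`) is an assumption and may vary.

```python
# Illustrative only: exercising the "relax.run.assert_op" packed function directly.
import numpy as np
import tvm
from tvm import relax  # importing relax ensures the packed function is registered

assert_fn = tvm.get_global_func("relax.run.assert_op")

# The condition must be a rank-0 boolean NDArray.
ok = tvm.nd.array(np.array(True, dtype="bool"))
assert_fn(ok, "")  # condition holds: no output, no error

bad = tvm.nd.array(np.array(False, dtype="bool"))
val = tvm.nd.array(np.array(6, dtype="int32"))
# Condition fails: an AssertionError whose message contains "Number: 6" is raised
# (it may be wrapped by the FFI when crossing language boundaries).
assert_fn(bad, "Number: {}", val)
```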

python/tvm/relax/op/op_attrs.py

+5
@@ -42,3 +42,8 @@ class UniqueAttrs(Attrs):
 @tvm._ffi.register_object("relax.attrs.PrintAttrs")
 class PrintAttrs(Attrs):
     """Attributes used for the print operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.AssertOpAttrs")
+class AssertOpAttrs(Attrs):
+    """Attributes used for the assert operator"""

src/relax/backend/vm/codegen_vm.cc

+9 −1

@@ -385,7 +385,14 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
     }
     if (call_node->op == print_op_) {
       auto print_attrs = call_node->attrs.as<PrintAttrs>();
-      args.push_back(EmitConstantFromValue(print_attrs->format));
+      // format string is the first argument
+      args.insert(args.begin(), EmitConstantFromValue(print_attrs->format));
+      return;
+    }
+    if (call_node->op == assert_op_) {
+      auto assert_attrs = call_node->attrs.as<AssertOpAttrs>();
+      // format string comes before the format args
+      args.insert(args.begin() + 1, EmitConstantFromValue(assert_attrs->format));
       return;
     }
     LOG(FATAL) << "Support for attributes of Op " << call_node->op
@@ -520,6 +527,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
   const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
   const Op& unique_op_ = Op::Get("relax.unique");
   const Op& print_op_ = Op::Get("relax.print");
+  const Op& assert_op_ = Op::Get("relax.assert_op");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
 };

src/relax/op/op.cc

+46
@@ -19,6 +19,7 @@
 #include <tvm/relax/attrs/memory.h>
 #include <tvm/relax/attrs/shape.h>
 #include <tvm/relax/expr.h>
+#include <tvm/relax/utils.h>
 #include <tvm/relay/op.h>
 
 #include "op_common.h"
@@ -118,6 +119,51 @@ Expr MakePrint(Array<Expr> vals, std::string format) {
 
 TVM_REGISTER_GLOBAL("relax.op.print").set_body_typed(MakePrint);
 
+// assert_op
+
+// can't actually name it assert or else Python will consider it a syntax error
+
+Type InferAssertType(const Call& call, DiagnosticContext diag_ctx) {
+  // Ensure that the condition argument is a boolean scalar.
+  // Also permitted is a tensor with unknown shape and unknown dtype
+  // (checked dynamically in that case). Returns void.
+  if (call->args.size() < 1) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "Assert must have at least one argument (the condition).");
+  }
+  Type arg_type = call->args[0]->checked_type();
+  if (!IsBoolScalarType(arg_type)) {
+    diag_ctx.EmitFatal(Diagnostic::Error(call->span)
+                       << "The argument to assert must be a boolean scalar type, but received "
+                       << arg_type);
+  }
+  return VoidType();
+}
+
+TVM_REGISTER_NODE_TYPE(AssertOpAttrs);
+
+RELAY_REGISTER_OP("relax.assert_op")
+    .set_attrs_type<AssertOpAttrs>()
+    .set_num_inputs(-1)
+    .add_argument("vals", "Array<Expr>",
+                  "The first value is used as the assertion condition. The others are used as "
+                  "format arguments if there is an error.")
+    .set_attr<FInferType>("FInferType", InferAssertType)
+    .set_attr<FCallPacked>("FCallPacked", "relax.run.assert_op");
+
+Expr MakeAssertOp(Expr condition, Array<Expr> vals, std::string format) {
+  auto attrs = make_object<AssertOpAttrs>();
+  attrs->format = format;
+  static const Op& op = Op::Get("relax.assert_op");
+  Array<Expr> args = {condition};
+  for (auto val : vals) {
+    args.push_back(val);
+  }
+  return Call(op, args, Attrs(attrs));
+}
+
+TVM_REGISTER_GLOBAL("relax.op.assert_op").set_body_typed(MakeAssertOp);
+
 // make_closure
 
 RELAY_REGISTER_OP("relax.make_closure")

src/relax/utils.cc

+10
@@ -67,5 +67,15 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
   }
 }
 
+bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
+  const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
+  if (!tt) {
+    return false;
+  }
+  bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void());
+  bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1);
+  return correct_dtype && correct_rank;
+}
+
 }  // namespace relax
 }  // namespace tvm

tests/python/relax/test_relax_operators.py

+56
@@ -21,6 +21,7 @@
 import pytest
 import tvm
 from tvm import relax
+from tvm._ffi.base import TVMError
 
 from tvm.script import relax as R
 
@@ -88,5 +89,60 @@ def test_print():
     sys.stdout = stdout
 
 
+@tvm.script.ir_module
+class AssertOpTest:
+    @R.function
+    def passes(x: Tensor((), "int32")):
+        p1 = relax.assert_op(relax.const(True))
+        return x
+
+    @R.function
+    def pass_with_args(x: Tensor((), "int32")):
+        p1 = relax.assert_op(relax.const(True), x, format="You won't see me")
+        return x
+
+    @R.function
+    def simple_fail(x: Tensor((), "int32")):
+        p1 = relax.assert_op(relax.const(False))
+        return x
+
+    @R.function
+    def fail_with_message(x: Tensor((), "int32")):
+        p1 = relax.assert_op(relax.const(False), format="I failed...")
+        return x
+
+    @R.function
+    def fail_with_args(x: Tensor((), "int32")):
+        # no format
+        p1 = relax.assert_op(relax.const(False), x, x)
+        return x
+
+    @R.function
+    def fail_with_formatted_message(x: Tensor((), "int32")):
+        p1 = relax.assert_op(relax.const(False), x, format="Number: {}")
+        return x
+
+
+def test_assert_op():
+    def check_assertion_error(func_name, func_arg, expected_message):
+        passed = False
+        try:
+            run_cpu(AssertOpTest, func_name, func_arg)
+            passed = True
+        except TVMError as e:
+            # TVM will print out a TVMError that will contain the
+            # generated error at the bottom of a stack trace
+            assert "AssertionError" in e.args[0]
+            assert expected_message in e.args[0]
+        assert not passed
+
+    run_cpu(AssertOpTest, "passes", tvm.nd.array(1))
+    run_cpu(AssertOpTest, "pass_with_args", tvm.nd.array(2))
+    check_assertion_error("simple_fail", tvm.nd.array(3), "Assertion Failed")
+    check_assertion_error("fail_with_message", tvm.nd.array(4), "I failed...")
+    check_assertion_error("fail_with_args", tvm.nd.array(5), "5, 5")
+    check_assertion_error("fail_with_formatted_message", tvm.nd.array(6), "Number: 6")
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
