@@ -140,53 +140,55 @@ def invoke_closure(
140
140
return _ffi_api .invoke_closure (closure , args )
141
141
142
142
143
+ def render_object (val : tvm .Object ) -> str :
144
+ """
145
+ Given a TVM Object, renders it in string form. Used for Relax printing and assertions.
146
+
147
+ Parameters
148
+ ----------
149
+ val: tvm.Object
150
+ An object to render
151
+
152
+ Returns
153
+ -------
154
+ ret: str
155
+ A string representing the value, ideally human-readable
156
+ """
157
+ if isinstance (val , tvm .runtime .ndarray .NDArray ):
158
+ return str (val )
159
+ # no pretty-printer by default, so if we don't handle this,
160
+ # then we can't look inside tuples
161
+ if isinstance (val , tvm .runtime .container .ADT ):
162
+ # the fields array of an ADT cannot be directly accessed in Python
163
+ # so we have to get the length and index into the fields separately
164
+ fields = ", " .join ([render_object (val [i ]) for i in range (len (val ))])
165
+ # special case: tag = 0 is a tuple
166
+ if val .tag == 0 :
167
+ return f"({ fields } )"
168
+ return f"ADT(tag={ val .tag } , fields=[{ fields } ])"
169
+ return str (val )
170
+
171
+
143
172
@tvm .register_func ("relax.run.print" )
144
- def relax_print (* args : List [ any ] ) -> None :
173
+ def relax_print (format_str : str , * format_args : tvm . Object ) -> None :
145
174
"""
146
175
Takes a list of values to print, formats with the given format string.
147
176
If the format string is empty, simply prints.
148
177
149
- Since this function is called as a PackedFunc from the generated code,
150
- we cannot have it be variadic _and_ have an optional format string attribute
151
- except by taking in all the arguments as a single list. The last argument
152
- should be a format string.
153
-
154
178
Call from TVM script like this:
155
179
`relax.print(value1, value2, ..., valueN, format=format_str)`
156
180
or
157
181
`relax.print(value1, value2, ..., valueN) # format_str defaults to ""`
158
182
159
183
Parameters
160
184
----------
161
- vals: List[Object]
162
- The values to print.
163
-
164
185
format_str: str
165
186
The last argument is a Python-style format string for printing the value
166
- """
167
-
168
- # there is no way to have a keyword arg to a packed function,
169
- # so the format string is always the last argument
170
- format_str = args [- 1 ]
171
- if not isinstance (format_str , str ):
172
- raise ValueError ("No valid format string given." )
173
-
174
- def render (val : tvm .Object ) -> str :
175
- if isinstance (val , tvm .runtime .ndarray .NDArray ):
176
- return str (val )
177
- # no pretty-printer by default, so if we don't handle this,
178
- # then we can't look inside tuples
179
- if isinstance (val , tvm .runtime .container .ADT ):
180
- # the fields array of an ADT cannot be directly accessed in Python
181
- # so we have to get the length and index into the fields separately
182
- fields = ", " .join ([render (val [i ]) for i in range (len (val ))])
183
- # special case: tag = 0 is a tuple
184
- if val .tag == 0 :
185
- return f"({ fields } )"
186
- return f"ADT(tag={ val .tag } , fields=[{ fields } ])"
187
- return str (val )
188
187
189
- val_strs = map (render , args [:- 1 ])
188
+ format_args: List[Object]
189
+ The values to print.
190
+ """
191
+ val_strs = map (render_object , format_args )
190
192
if format_str == "" :
191
193
py_print (* val_strs )
192
194
else :
@@ -214,6 +216,85 @@ def print(values: Union[Expr, List[Expr]], format: str) -> Expr:
214
216
return _ffi_api .print (values , format ) # type: ignore # pylint: disable=no-member
215
217
216
218
219
+ @tvm .register_func ("relax.run.assert_op" )
220
+ def relax_assert_op (condition : tvm .Object , format_str : str , * format_args : tvm .Object ) -> None :
221
+ """
222
+ A variadic function. The first value serves as the assertion condition:
223
+ If the condition is true, then the operator does nothing.
224
+ If the condition is false, then the operator raises an assertion error.
225
+
226
+ Arguments after the first value serve as format arguments for the error message;
227
+ the last argument must be a format string for the error message (empty by default).
228
+ If the format string is the empty string, then the error message will simply include
229
+ a comma-separated list of the format arguments.
230
+ The condition argument is not included in the format string.
231
+
232
+ Parameters
233
+ ----------
234
+ condition: tvm.Object
235
+ The assertion condition. Must be a boolean scalar.
236
+
237
+ format_str: str
238
+ The last argument is a Python-style format string for printing the value
239
+
240
+ format_args: List[tvm.Object]
241
+ Values used for formatting the string.
242
+ """
243
+ if not isinstance (format_str , str ):
244
+ raise ValueError (
245
+ f"The format string argument to assert must be a string, given { type (format_str )} )"
246
+ )
247
+
248
+ # should be guaranteed by the type system
249
+ if not isinstance (condition , tvm .runtime .ndarray .NDArray ):
250
+ raise ValueError (f"The condition must be an NDArray, but given a { type (condition )} ." )
251
+
252
+ # may happen if the original program had unknown shape or dtype for the tensor's type
253
+ dtype = condition .dtype
254
+ if dtype != "bool" :
255
+ raise ValueError (f"The condition must be a bool scalar, but given a { dtype } tensor" )
256
+ shape = condition .shape
257
+ if len (shape ) != 0 :
258
+ raise ValueError (f"The condition must be a scalar, but it has a shape of { shape } " )
259
+
260
+ val = condition .numpy ()
261
+ if not val :
262
+ error_message = "Assertion Failed"
263
+ if format_args or format_str != "" :
264
+ rendered = map (render_object , format_args )
265
+ if format_str != "" :
266
+ error_message = format_str .format (* rendered )
267
+ else :
268
+ error_message = ", " .join (rendered )
269
+ raise AssertionError (error_message )
270
+
271
+
272
+ def assert_op (condition : Expr , format_args : Optional [List [Expr ]] = None , format : str = "" ) -> Expr :
273
+ """
274
+ Create a call to Relax's assert_op operation (`assert` is reserved in Python,
275
+ so the name must be distinct).
276
+
277
+ Parameters
278
+ ----------
279
+ condition: Expr
280
+ The assertion condition.
281
+
282
+ format_args: List[Expr]
283
+ Format arguments for the error message if the condition fails.
284
+
285
+ format_str: str
286
+ The format string for the error message.
287
+
288
+ Returns
289
+ -------
290
+ result : Expr
291
+ A Call to the Relax assert operation.
292
+ """
293
+ if format_args is None :
294
+ format_args = []
295
+ return _ffi_api .assert_op (condition , format_args , format ) # type: ignore
296
+
297
+
217
298
def shape_of (expr : Expr ) -> Expr :
218
299
"""Get shape of a tensor.
219
300
0 commit comments