Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit ee49c24

Browse files
tqchenHzfengsyMasterJH5574
authored andcommitted
[REFACTOR][ARCH] Introduce StructInfo M0 (#314)
* [IR] Introduce StructInfo * StructInfoFunctor and Analysis Support * [TVMScript] Parse type/shape annotation with StructInfo * remove runtime type assign * Remove type/shape during parsing (#2) * Normalizer prep: simple checks and legacy function renaming. * Struct info deduction in BlockBuilder. * Two TODOs * StructInfo Normalizer Fixes (#3) * StructInfo AST Fix * Fix Extern Func Deduction and shape mutator. * Update VoidStructInfo & globalvar (#4) * Fix passes and proper sinfo propagation. * Refactor EraseToWellDefined to Enable Remapping * [WIP] First stab at symbolic param tracking * Update EraseToWellDefined to support symbolic shape return (#5) * fix R.shape with ndim (#6) * Remove update shape/type * Address review comment, AnnotateTypeShape=>AnnotateStructInfo * Update include/tvm/script/ir_builder/relax/frame.h Co-authored-by: Ruihang Lai <[email protected]> * Address comments * Update printer to use structinfo (#7) * Update Error mechanism to prep for obj loc based reporting * Symbolic shape aware function call return value derivation. The main flow works as follows: - Match and populate shape_var_map and var_map by visit each pair of param and call arguments. - Call EraseToWellDefined to map the ret parameter to new result. * [ANALYSIS] Refactor well-form to only look at struct info. * Update comments according to reviews. * Update include/tvm/relax/struct_info.h Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Tianqi Chen <tqchen> Co-authored-by: Ruihang Lai <[email protected]>
1 parent 02d4104 commit ee49c24

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+4971
-1289
lines changed

include/tvm/ir/diagnostic.h

+27
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,14 @@ class DiagnosticNode : public Object {
5656
DiagnosticLevel level;
5757
/*! \brief The span at which to report an error. */
5858
Span span;
59+
/*!
60+
* \brief The object location at which to report an error.
61+
*
62+
* The object loc provides a location when span is not always
63+
* available during transformation. The error reporter can
64+
* still pick up loc->span if necessary.
65+
*/
66+
ObjectRef loc;
5967
/*! \brief The diagnostic message. */
6068
String message;
6169

@@ -84,6 +92,18 @@ class Diagnostic : public ObjectRef {
8492
static DiagnosticBuilder Warning(Span span);
8593
static DiagnosticBuilder Note(Span span);
8694
static DiagnosticBuilder Help(Span span);
95+
// variants uses object location
96+
static DiagnosticBuilder Bug(ObjectRef loc);
97+
static DiagnosticBuilder Error(ObjectRef loc);
98+
static DiagnosticBuilder Warning(ObjectRef loc);
99+
static DiagnosticBuilder Note(ObjectRef loc);
100+
static DiagnosticBuilder Help(ObjectRef loc);
101+
// variants uses object ptr.
102+
static DiagnosticBuilder Bug(const Object* loc);
103+
static DiagnosticBuilder Error(const Object* loc);
104+
static DiagnosticBuilder Warning(const Object* loc);
105+
static DiagnosticBuilder Note(const Object* loc);
106+
static DiagnosticBuilder Help(const Object* loc);
87107

88108
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Diagnostic, ObjectRef, DiagnosticNode);
89109
};
@@ -102,6 +122,11 @@ class DiagnosticBuilder {
102122
/*! \brief The span of the diagnostic. */
103123
Span span;
104124

125+
/*!
126+
* \brief The object location at which to report an error.
127+
*/
128+
ObjectRef loc;
129+
105130
template <typename T>
106131
DiagnosticBuilder& operator<<(const T& val) { // NOLINT(*)
107132
stream_ << val;
@@ -115,6 +140,8 @@ class DiagnosticBuilder {
115140

116141
DiagnosticBuilder(DiagnosticLevel level, Span span) : level(level), span(span) {}
117142

143+
DiagnosticBuilder(DiagnosticLevel level, ObjectRef loc) : level(level), loc(loc) {}
144+
118145
operator Diagnostic() { return Diagnostic(this->level, this->span, this->stream_.str()); }
119146

120147
private:

include/tvm/ir/expr.h

+8
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,13 @@ class RelayExprNode : public BaseExprNode {
378378
*/
379379
mutable Optional<ObjectRef> shape_ = Optional<ObjectRef>();
380380

381+
/*!
382+
* \brief Stores the result of structure information of the
383+
* expression that encapsulate both static shape and
384+
* runtime information such as shape.
385+
*/
386+
mutable Optional<ObjectRef> struct_info_ = Optional<ObjectRef>();
387+
381388
/*!
382389
* \return The checked_type
383390
*/
@@ -473,6 +480,7 @@ class GlobalVarNode : public RelayExprNode {
473480
v->Visit("virtual_device_", &virtual_device_);
474481
v->Visit("span", &span);
475482
v->Visit("_checked_type_", &checked_type_);
483+
v->Visit("struct_info_", &struct_info_);
476484
}
477485

478486
bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {

include/tvm/ir/type.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class PrimType : public Type {
132132
* \brief Constructor
133133
* \param dtype The corresponding dtype.
134134
*/
135-
TVM_DLL explicit PrimType(runtime::DataType dtype);
135+
TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span());
136136

137137
TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
138138
};

include/tvm/relax/analysis.h

+245-1
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,266 @@
1919

2020
/*!
2121
* \file tvm/relax/analysis.h
22-
* \brief The set of Relax specific analysis passes.
22+
* \brief The set of Relax specific analysis on IR.
2323
*/
2424
#ifndef TVM_RELAX_ANALYSIS_H_
2525
#define TVM_RELAX_ANALYSIS_H_
2626

27+
#include <tvm/arith/analyzer.h>
2728
#include <tvm/ir/diagnostic.h>
2829
#include <tvm/ir/module.h>
2930
#include <tvm/relax/expr.h>
31+
#include <tvm/relax/struct_info.h>
3032
#include <tvm/relay/op_attr_types.h>
3133
#include <tvm/tir/function.h>
3234

35+
#include <functional>
3336
#include <utility>
3437

3538
namespace tvm {
3639
namespace relax {
40+
//-----------------------------------
41+
// Shape expression analysis
42+
//----------------------------------
43+
/*!
44+
* \brief Can prove the two symbolic shape arrays equals to each other.
45+
*
46+
* \param lhs The left operand.
47+
* \param rhs The right operand.
48+
* \param ana The analyzer used for integer analysis.
49+
* \return The prove result.
50+
*
51+
* \note This function does best effort prove, which means
52+
* if result is false, there is still possibility that
53+
* two shapes equals to each other during runtime.
54+
*/
55+
TVM_DLL bool CanProveShapeEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs,
56+
arith::Analyzer* ana);
57+
58+
/*!
59+
* \brief Can prove the two symbolic shape expressions equals to each other.
60+
*
61+
* \param lhs The left operand.
62+
* \param rhs The right operand.
63+
* \param ana The analyzer used for integer analysis.
64+
*
65+
* \note This function does best effort prove, which means
66+
* if result is false, there is still possibility that
67+
* two shapes equals to each other during runtime.
68+
*/
69+
TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);
70+
71+
//-----------------------------------
72+
// Foundational StructInfo analysis
73+
//-----------------------------------
74+
/*!
75+
* \brief Get the corresponding static type from a given struct info.
76+
* \param info The struct info.
77+
* \return the corresponding static type.
78+
*/
79+
TVM_DLL Type GetStaticType(const StructInfo& info);
80+
81+
/*!
82+
* \brief Get the corresponding struct info from static type.
83+
* \param type The input type
84+
* \return the corresponding struct info.
85+
*/
86+
TVM_DLL StructInfo StructInfoFromType(const Type& type);
87+
88+
// TODO(relax-team): Remove legacy shape related functionalities after phasing out shape_
89+
/*!
90+
* \brief Get the corresponding struct info from static type.
91+
* \param type The input type
92+
* \param shape_hint The shape hint
93+
* \return the corresponding struct info.
94+
*/
95+
TVM_DLL StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional<Expr> shape_hint);
96+
97+
/*!
98+
* \brief Get the corresponding legacy shape hint from struct info
99+
* \param info The struct info.
100+
* \return the corresponding legacy shape hint.
101+
*/
102+
TVM_DLL Optional<Expr> GetLegacyShapeHint(const StructInfo& info);
103+
104+
/*!
105+
* \return Derive the call's ret value struct info from inputs.
106+
* \param func_info The function struct info.
107+
* \param call The call expression to be derived.
108+
* \param ctx The builder context.
109+
* \param ana Optional context analyzer to prove symbolic expression equality.
110+
* \return The derived struct info of the call.
111+
* \note call->op field is ignored during derivation and we only rely on information
112+
* presented by func_sinfo.
113+
*/
114+
TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call,
115+
const BlockBuilder& ctx, arith::Analyzer* ana = nullptr);
116+
117+
/*!
118+
* \brief Erase the info to a corresponding more coarse grained
119+
* struct info that is still well-defined(with all the vars in scope).
120+
*
121+
* When we are returning a StructInfo to another scope,
122+
* it is important to remember that StructInfo may carry
123+
* dependencies on var that is not defined the other scope.
124+
*
125+
* In such cases, it is important to call EraseToWellDefined to get
126+
* another StructInfo that **only** contains the vars that are defined
127+
* in the target scope.
128+
*
129+
* For example, consider the following function
130+
*
131+
* \code
132+
*
133+
* @R.function
134+
* def f(x: R.Tensor[(n, m)]):
135+
* k = tir.Var("k", "int64")
136+
* v0 = opaque_fn(x)
137+
* v1 = match_cast(v0, R.Tensor[(n, k)])
138+
* v2 : R.Tensor[(n + 1, k + 2)] = pad(v1)
139+
* return v2
140+
*
141+
* \endcode
142+
*
143+
* In the above code, the return value y have shape `(n + 1, k + 2)`,
144+
* However, at the level of function signature, only n, m are defined,
145+
* k is undefined here.
146+
*
147+
* When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], fshape_var_map={n: n, m: m}),
148+
* we will obtain R.Tensor(ndim=2), which is an erased info that does not depend
149+
* on k(which is undefined from parameter signature).
150+
*
151+
* However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: n, m: m}),
152+
* Then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined.
153+
*
154+
* We can also make these var map to return a different expression.
155+
* For example, EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: 2, m: m})
156+
* will give us R.Tensor[(3, m)], where n get replaced by 2.
157+
*
158+
* Use this function in the following scenarios:
159+
* - Decide the struct_info of expr with sub-scopes, such as If, SeqExpr
160+
* - Decide the deduced return struct_info of a function that can be fully decided by params.
161+
*
162+
* \param info The struct info.
163+
* \param f_shape_var_map callback function to specify
164+
* whether a symbolic shape var is defined and the value it maps to,
165+
* return nullopt if var is undefined.
166+
* \param f_var_defined callback function to specify
167+
* whether a var is defined in the target scope and the value it maps to,
168+
* return nullopt if var is undefined.
169+
* \param ana Optional context analyzer to prove symbolic expression equality.
170+
*
171+
* \return the corresponding erased struct info.
172+
*/
173+
TVM_DLL StructInfo
174+
EraseToWellDefined(const StructInfo& info,
175+
std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
176+
std::function<Optional<Expr>(const Var& var)> f_var_map = nullptr,
177+
arith::Analyzer* ana = nullptr);
178+
179+
/*!
180+
* \brief EraseToWellDefined variant with map.
181+
* \param info The struct info.
182+
* \param f_shape_var_map callback function to specify
183+
* whether a symbolic shape var is defined and the value it maps to,
184+
* return nullopt if var is undefined.
185+
* \param f_var_defined callback function to specify
186+
* whether a var is defined in the target scope and the value it maps to,
187+
* return nullopt if var is undefined.
188+
* \param ana Optional context analyzer to prove symbolic expression equality.
189+
*
190+
* \return the corresponding erased struct info.
191+
*/
192+
TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map<tir::Var, PrimExpr> shape_var_map,
193+
Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);
194+
195+
/*!
196+
* \brief Fine grained result of base check.
197+
*
198+
* This analysis comes with different levels of checking failures
199+
* that can help to customize the compilation decisions.
200+
*
201+
* For a given pair of lhs_struct_info, rhs_struct_info. We adopt
202+
* the following terminology:
203+
* - LSet = {value | value mactches lhs_struct_info}
204+
* - RSet = {value | value mactches rhs_struct_info}
205+
*
206+
* See the definition of each level below.
207+
*/
208+
enum class BaseCheckResult {
209+
/*!
210+
* \brief The two value sets have no intersection at all: Interset(LSet, RSet) = empty
211+
*/
212+
kFailL0 = 0,
213+
/*!
214+
* \brief LSet is not superset of RSet by only looking at static information.
215+
*
216+
* \note This level will trigger static type checking error when lhs is param and rhs is arg.
217+
*/
218+
kFailL1 = 1,
219+
/*!
220+
* \brief WLSet is not superset of RSet because of mismatch in value information.
221+
*
222+
* L1-level mismatches in params of FuncStructInfo is categorized as
223+
* If lhs is FuncStructInfo, then L1-level mismatch in its params
224+
* is categorized as L2-level mismatch for lhs.
225+
*
226+
* Design considerations for functions:
227+
* - (a) We want to be able to erase type/value in function signature
228+
* when we unify function struct info and preserve simpler representations.
229+
* - (b) We automatically insert match_cast at function boundary, so
230+
* we can erase (int)->int argument as (object)->int.
231+
* The input shape/type mismatch will be detected by runtime checks at function boundary.
232+
* This behavior is also consistent with the PackedFunc behavior.
233+
*
234+
* \note This level means there is no problem about static known information.
235+
* It is OK for the checker to do best effort and return this value.
236+
*/
237+
kFailL2 = 2,
238+
/*! \brief LSet is superset of RSet. */
239+
kPass = 3
240+
};
241+
242+
/*!
243+
* \brief Run a base check to see if base subsumes derived.
244+
*
245+
* This function returns fine-grained base-check result on reasons of failure.
246+
*
247+
* \param base The base struct info.
248+
* \param derived The derived struct info.
249+
* \param ana Optional context analyzer to prove symbolic expression equality.
250+
* \return Whether the relation holds.
251+
*
252+
* \sa BaseCheckResult
253+
*/
254+
TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
255+
arith::Analyzer* ana = nullptr);
256+
257+
/*!
258+
* \brief Check the relation of two struct info to see if one subsumes another one.
259+
*
260+
* \param base The base struct info.
261+
* \param derived The derived struct info.
262+
* \param ana Optional context analyzer to prove symbolic expression equality.
263+
* \return Whether the relation holds.
264+
*/
265+
TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
266+
arith::Analyzer* ana = nullptr);
267+
268+
/*!
269+
* \brief Unify the two struct info their least common ancestor.
270+
*
271+
* \param lhs The left operand.
272+
* \param rhs The right operand.
273+
* \param ana Optional context analyzer to prove symbolic expression equality.
274+
* \return The unified information.
275+
*/
276+
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
277+
arith::Analyzer* ana = nullptr);
37278

279+
//-----------------------------------
280+
// General IR analysis
281+
//----------------------------------
38282
/*!
39283
* \brief Check if the IRModule is well formed.
40284
*

0 commit comments

Comments
 (0)