|
19 | 19 |
|
20 | 20 | /*!
|
21 | 21 | * \file tvm/relax/analysis.h
|
22 |
| - * \brief The set of Relax specific analysis passes. |
| 22 | + * \brief The set of Relax specific analysis on IR. |
23 | 23 | */
|
24 | 24 | #ifndef TVM_RELAX_ANALYSIS_H_
|
25 | 25 | #define TVM_RELAX_ANALYSIS_H_
|
26 | 26 |
|
| 27 | +#include <tvm/arith/analyzer.h> |
27 | 28 | #include <tvm/ir/diagnostic.h>
|
28 | 29 | #include <tvm/ir/module.h>
|
29 | 30 | #include <tvm/relax/expr.h>
|
| 31 | +#include <tvm/relax/struct_info.h> |
30 | 32 | #include <tvm/relay/op_attr_types.h>
|
31 | 33 | #include <tvm/tir/function.h>
|
32 | 34 |
|
| 35 | +#include <functional> |
33 | 36 | #include <utility>
|
34 | 37 |
|
35 | 38 | namespace tvm {
|
36 | 39 | namespace relax {
|
| 40 | +//----------------------------------- |
| 41 | +// Shape expression analysis |
| 42 | +//---------------------------------- |
| 43 | +/*! |
| 44 | + * \brief Can prove the two symbolic shape arrays equals to each other. |
| 45 | + * |
| 46 | + * \param lhs The left operand. |
| 47 | + * \param rhs The right operand. |
| 48 | + * \param ana The analyzer used for integer analysis. |
| 49 | + * \return The prove result. |
| 50 | + * |
| 51 | + * \note This function does best effort prove, which means |
| 52 | + * if result is false, there is still possibility that |
| 53 | + * two shapes equals to each other during runtime. |
| 54 | + */ |
| 55 | +TVM_DLL bool CanProveShapeEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs, |
| 56 | + arith::Analyzer* ana); |
| 57 | + |
| 58 | +/*! |
| 59 | + * \brief Can prove the two symbolic shape expressions equals to each other. |
| 60 | + * |
| 61 | + * \param lhs The left operand. |
| 62 | + * \param rhs The right operand. |
| 63 | + * \param ana The analyzer used for integer analysis. |
| 64 | + * |
| 65 | + * \note This function does best effort prove, which means |
| 66 | + * if result is false, there is still possibility that |
| 67 | + * two shapes equals to each other during runtime. |
| 68 | + */ |
| 69 | +TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana); |
| 70 | + |
| 71 | +//----------------------------------- |
| 72 | +// Foundational StructInfo analysis |
| 73 | +//----------------------------------- |
| 74 | +/*! |
| 75 | + * \brief Get the corresponding static type from a given struct info. |
| 76 | + * \param info The struct info. |
| 77 | + * \return the corresponding static type. |
| 78 | + */ |
| 79 | +TVM_DLL Type GetStaticType(const StructInfo& info); |
| 80 | + |
| 81 | +/*! |
| 82 | + * \brief Get the corresponding struct info from static type. |
| 83 | + * \param type The input type |
| 84 | + * \return the corresponding struct info. |
| 85 | + */ |
| 86 | +TVM_DLL StructInfo StructInfoFromType(const Type& type); |
| 87 | + |
| 88 | +// TODO(relax-team): Remove legacy shape related functionalities after phasing out shape_ |
| 89 | +/*! |
| 90 | + * \brief Get the corresponding struct info from static type. |
| 91 | + * \param type The input type |
| 92 | + * \param shape_hint The shape hint |
| 93 | + * \return the corresponding struct info. |
| 94 | + */ |
| 95 | +TVM_DLL StructInfo StructInfoFromTypeLegacyShapeHint(const Type& type, Optional<Expr> shape_hint); |
| 96 | + |
| 97 | +/*! |
| 98 | + * \brief Get the corresponding legacy shape hint from struct info |
| 99 | + * \param info The struct info. |
| 100 | + * \return the corresponding legacy shape hint. |
| 101 | + */ |
| 102 | +TVM_DLL Optional<Expr> GetLegacyShapeHint(const StructInfo& info); |
| 103 | + |
| 104 | +/*! |
| 105 | + * \return Derive the call's ret value struct info from inputs. |
| 106 | + * \param func_info The function struct info. |
| 107 | + * \param call The call expression to be derived. |
| 108 | + * \param ctx The builder context. |
| 109 | + * \param ana Optional context analyzer to prove symbolic expression equality. |
| 110 | + * \return The derived struct info of the call. |
| 111 | + * \note call->op field is ignored during derivation and we only rely on information |
| 112 | + * presented by func_sinfo. |
| 113 | + */ |
| 114 | +TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, |
| 115 | + const BlockBuilder& ctx, arith::Analyzer* ana = nullptr); |
| 116 | + |
| 117 | +/*! |
| 118 | + * \brief Erase the info to a corresponding more coarse grained |
| 119 | + * struct info that is still well-defined(with all the vars in scope). |
| 120 | + * |
| 121 | + * When we are returning a StructInfo to another scope, |
| 122 | + * it is important to remember that StructInfo may carry |
| 123 | + * dependencies on var that is not defined the other scope. |
| 124 | + * |
| 125 | + * In such cases, it is important to call EraseToWellDefined to get |
| 126 | + * another StructInfo that **only** contains the vars that are defined |
| 127 | + * in the target scope. |
| 128 | + * |
| 129 | + * For example, consider the following function |
| 130 | + * |
| 131 | + * \code |
| 132 | + * |
| 133 | + * @R.function |
| 134 | + * def f(x: R.Tensor[(n, m)]): |
| 135 | + * k = tir.Var("k", "int64") |
| 136 | + * v0 = opaque_fn(x) |
| 137 | + * v1 = match_cast(v0, R.Tensor[(n, k)]) |
| 138 | + * v2 : R.Tensor[(n + 1, k + 2)] = pad(v1) |
| 139 | + * return v2 |
| 140 | + * |
| 141 | + * \endcode |
| 142 | + * |
| 143 | + * In the above code, the return value y have shape `(n + 1, k + 2)`, |
| 144 | + * However, at the level of function signature, only n, m are defined, |
| 145 | + * k is undefined here. |
| 146 | + * |
| 147 | + * When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], fshape_var_map={n: n, m: m}), |
| 148 | + * we will obtain R.Tensor(ndim=2), which is an erased info that does not depend |
| 149 | + * on k(which is undefined from parameter signature). |
| 150 | + * |
| 151 | + * However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: n, m: m}), |
| 152 | + * Then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined. |
| 153 | + * |
| 154 | + * We can also make these var map to return a different expression. |
| 155 | + * For example, EraseToWellDefined(R.Tensor[(n + 1, m)], fshape_var_map={n: 2, m: m}) |
| 156 | + * will give us R.Tensor[(3, m)], where n get replaced by 2. |
| 157 | + * |
| 158 | + * Use this function in the following scenarios: |
| 159 | + * - Decide the struct_info of expr with sub-scopes, such as If, SeqExpr |
| 160 | + * - Decide the deduced return struct_info of a function that can be fully decided by params. |
| 161 | + * |
| 162 | + * \param info The struct info. |
| 163 | + * \param f_shape_var_map callback function to specify |
| 164 | + * whether a symbolic shape var is defined and the value it maps to, |
| 165 | + * return nullopt if var is undefined. |
| 166 | + * \param f_var_defined callback function to specify |
| 167 | + * whether a var is defined in the target scope and the value it maps to, |
| 168 | + * return nullopt if var is undefined. |
| 169 | + * \param ana Optional context analyzer to prove symbolic expression equality. |
| 170 | + * |
| 171 | + * \return the corresponding erased struct info. |
| 172 | + */ |
| 173 | +TVM_DLL StructInfo |
| 174 | +EraseToWellDefined(const StructInfo& info, |
| 175 | + std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr, |
| 176 | + std::function<Optional<Expr>(const Var& var)> f_var_map = nullptr, |
| 177 | + arith::Analyzer* ana = nullptr); |
| 178 | + |
| 179 | +/*! |
| 180 | + * \brief EraseToWellDefined variant with map. |
| 181 | + * \param info The struct info. |
| 182 | + * \param f_shape_var_map callback function to specify |
| 183 | + * whether a symbolic shape var is defined and the value it maps to, |
| 184 | + * return nullopt if var is undefined. |
| 185 | + * \param f_var_defined callback function to specify |
| 186 | + * whether a var is defined in the target scope and the value it maps to, |
| 187 | + * return nullopt if var is undefined. |
| 188 | + * \param ana Optional context analyzer to prove symbolic expression equality. |
| 189 | + * |
| 190 | + * \return the corresponding erased struct info. |
| 191 | + */ |
| 192 | +TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map<tir::Var, PrimExpr> shape_var_map, |
| 193 | + Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr); |
| 194 | + |
| 195 | +/*! |
| 196 | + * \brief Fine grained result of base check. |
| 197 | + * |
| 198 | + * This analysis comes with different levels of checking failures |
| 199 | + * that can help to customize the compilation decisions. |
| 200 | + * |
| 201 | + * For a given pair of lhs_struct_info, rhs_struct_info. We adopt |
| 202 | + * the following terminology: |
| 203 | + * - LSet = {value | value mactches lhs_struct_info} |
| 204 | + * - RSet = {value | value mactches rhs_struct_info} |
| 205 | + * |
| 206 | + * See the definition of each level below. |
| 207 | + */ |
| 208 | +enum class BaseCheckResult { |
| 209 | + /*! |
| 210 | + * \brief The two value sets have no intersection at all: Interset(LSet, RSet) = empty |
| 211 | + */ |
| 212 | + kFailL0 = 0, |
| 213 | + /*! |
| 214 | + * \brief LSet is not superset of RSet by only looking at static information. |
| 215 | + * |
| 216 | + * \note This level will trigger static type checking error when lhs is param and rhs is arg. |
| 217 | + */ |
| 218 | + kFailL1 = 1, |
| 219 | + /*! |
| 220 | + * \brief WLSet is not superset of RSet because of mismatch in value information. |
| 221 | + * |
| 222 | + * L1-level mismatches in params of FuncStructInfo is categorized as |
| 223 | + * If lhs is FuncStructInfo, then L1-level mismatch in its params |
| 224 | + * is categorized as L2-level mismatch for lhs. |
| 225 | + * |
| 226 | + * Design considerations for functions: |
| 227 | + * - (a) We want to be able to erase type/value in function signature |
| 228 | + * when we unify function struct info and preserve simpler representations. |
| 229 | + * - (b) We automatically insert match_cast at function boundary, so |
| 230 | + * we can erase (int)->int argument as (object)->int. |
| 231 | + * The input shape/type mismatch will be detected by runtime checks at function boundary. |
| 232 | + * This behavior is also consistent with the PackedFunc behavior. |
| 233 | + * |
| 234 | + * \note This level means there is no problem about static known information. |
| 235 | + * It is OK for the checker to do best effort and return this value. |
| 236 | + */ |
| 237 | + kFailL2 = 2, |
| 238 | + /*! \brief LSet is superset of RSet. */ |
| 239 | + kPass = 3 |
| 240 | +}; |
| 241 | + |
| 242 | +/*! |
| 243 | + * \brief Run a base check to see if base subsumes derived. |
| 244 | + * |
| 245 | + * This function returns fine-grained base-check result on reasons of failure. |
| 246 | + * |
| 247 | + * \param base The base struct info. |
| 248 | + * \param derived The derived struct info. |
| 249 | + * \param ana Optional context analyzer to prove symbolic expression equality. |
| 250 | + * \return Whether the relation holds. |
| 251 | + * |
| 252 | + * \sa BaseCheckResult |
| 253 | + */ |
| 254 | +TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, |
| 255 | + arith::Analyzer* ana = nullptr); |
| 256 | + |
| 257 | +/*! |
| 258 | + * \brief Check the relation of two struct info to see if one subsumes another one. |
| 259 | + * |
| 260 | + * \param base The base struct info. |
| 261 | + * \param derived The derived struct info. |
| 262 | + * \param ana Optional context analyzer to prove symbolic expression equality. |
| 263 | + * \return Whether the relation holds. |
| 264 | + */ |
| 265 | +TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, |
| 266 | + arith::Analyzer* ana = nullptr); |
| 267 | + |
| 268 | +/*! |
| 269 | + * \brief Unify the two struct info their least common ancestor. |
| 270 | + * |
| 271 | + * \param lhs The left operand. |
| 272 | + * \param rhs The right operand. |
| 273 | + * \param ana Optional context analyzer to prove symbolic expression equality. |
| 274 | + * \return The unified information. |
| 275 | + */ |
| 276 | +TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, |
| 277 | + arith::Analyzer* ana = nullptr); |
37 | 278 |
|
| 279 | +//----------------------------------- |
| 280 | +// General IR analysis |
| 281 | +//---------------------------------- |
38 | 282 | /*!
|
39 | 283 | * \brief Check if the IRModule is well formed.
|
40 | 284 | *
|
|
0 commit comments