@@ -2040,100 +2040,106 @@ bool CheckSameArray(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
2040
2040
return true ;
2041
2041
}
2042
2042
2043
+ bool CheckElemwisePattern (const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
2044
+ if (indices_l.size () != indices_r.size ()) {
2045
+ return false ;
2046
+ }
2047
+ int n = indices_l.size ();
2048
+ for (int i = 0 ; i < n; i++) {
2049
+ if (!indices_l[i].same_as (indices_r[i])) {
2050
+ return false ;
2051
+ }
2052
+ }
2053
+ return true ;
2054
+ }
2055
+
2056
+ bool CheckBroadcastPattern (const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
2057
+ if (indices_l.size () < indices_r.size ()) {
2058
+ return false ;
2059
+ }
2060
+ int j=0 ;
2061
+ for (int i = 0 ; i < static_cast <int >(indices_r.size ()); i++) {
2062
+ for (; j < static_cast <int >(indices_l.size ()) && !indices_l[j].same_as
2063
+ (indices_r[i]); j++);
2064
+ if (j==static_cast <int >(indices_l.size ())){
2065
+ return false ;
2066
+ }
2067
+ }
2068
+ return true ;
2069
+ }
2070
+
2071
+ bool CheckInjectivePattern (const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
2072
+ std::unordered_set<const VarNode*> vars;
2073
+ for (int i = 0 ; i < static_cast <int >(indices_l.size ()); i++) {
2074
+ if (const auto * v = indices_l[i].as <VarNode>()) {
2075
+ vars.insert (v);
2076
+ } else {
2077
+ return false ;
2078
+ }
2079
+ }
2080
+ for (int i = 0 ; i < static_cast <int >(indices_r.size ()); i++) {
2081
+ if (tir::UsesVar (indices_r[i],
2082
+ [&vars](const VarNode* var) { return !vars.count (var); })) {
2083
+ return false ;
2084
+ }
2085
+ }
2086
+ return true ;
2087
+ }
2088
+
2043
2089
class PatternKindAnalyzer : public StmtExprVisitor {
2044
2090
void VisitStmt_ (const BufferStoreNode* op) final {
2045
- indices_. push_back ( op->indices ) ;
2091
+ store_indices_ = op->indices ;
2046
2092
StmtVisitor::VisitStmt_ (op);
2047
2093
}
2048
2094
void VisitExpr_ (const BufferLoadNode* op) final {
2049
- indices_ .push_back (op->indices );
2095
+ load_indices_ .push_back (op->indices );
2050
2096
ExprVisitor::VisitExpr_ (op);
2051
2097
}
2052
2098
2053
- void VisitExpr_ (const CallNode* op) final {
2054
- kind_=relay::kOpaque ;
2055
- }
2056
-
2057
2099
void VisitStmt_ (const BlockNode* op)final {
2058
2100
if (op->name_hint == " root" ) {
2059
2101
StmtVisitor::VisitStmt (op->body );
2060
2102
return ;
2061
2103
}
2062
-
2063
- relay::OpPatternKind kind = relay::kOpaque ;
2064
2104
2065
- // test whether is elemwise
2066
- indices_ .clear ();
2105
+ load_indices_. clear ();
2106
+ store_indices_ .clear ();
2067
2107
StmtVisitor::VisitStmt (op->body );
2068
- bool same_index = true ;
2069
- for (int i = 1 ; i < static_cast <int >(indices_.size ()); i++) {
2070
- if (!CheckSameArray (indices_[0 ],indices_[i])) {
2071
- same_index = false ;
2072
- break ;
2108
+
2109
+ relay::OpPatternKind index_pair_pattern = relay::kElemWise ;
2110
+ if (load_indices_.empty ()) {
2111
+ index_pair_pattern = relay::kBroadcast ;
2112
+ } else {
2113
+ for (int i = 0 ; i < static_cast <int >(load_indices_.size ()); i++) {
2114
+ if (CheckElemwisePattern (store_indices_, load_indices_[i])) {
2115
+ index_pair_pattern = std::max (index_pair_pattern, relay::kElemWise );
2116
+ } else if (CheckBroadcastPattern (store_indices_, load_indices_[i])) {
2117
+ index_pair_pattern = std::max (index_pair_pattern, relay::kBroadcast );
2118
+ } else if (CheckInjectivePattern (store_indices_, load_indices_[i])) {
2119
+ index_pair_pattern = std::max (index_pair_pattern, relay::kInjective );
2120
+ } else {
2121
+ index_pair_pattern = relay::kOpaque ;
2122
+ break ;
2123
+ }
2073
2124
}
2074
2125
}
2075
- if (same_index) {
2076
- kind = relay::kElemWise ;
2077
- kind_ = static_cast <int >(kind)>static_cast <int >(kind_)?kind:kind_;
2126
+ if (index_pair_pattern != relay::kOpaque ) {
2127
+ kind_ = std::max (kind_, index_pair_pattern);
2078
2128
return ;
2079
2129
}
2080
2130
2081
- if (const auto * store = op->body .as <BufferStoreNode>()) {
2082
- if (const auto * load = store->value .as <BufferLoadNode>()) {
2083
- // test whether is broadcast
2084
- int j = 0 ;
2085
- bool all_var_axis = true ;
2086
- for (int i = 0 ; i < static_cast <int >(load->indices .size ()); i++) {
2087
- if (load->indices [i].as <VarNode>()) {
2088
- for (; j < static_cast <int >(store->indices .size ()) && !store->indices [j].same_as
2089
- (load->indices [i]); j++);
2090
- } else {
2091
- all_var_axis = false ;
2092
- break ;
2093
- }
2094
- }
2095
- if (all_var_axis && j != static_cast <int >(store->indices .size ())) {
2096
- kind = relay::kBroadcast ;
2097
- kind_ = static_cast <int >(kind)>static_cast <int >(kind_)?kind:kind_;
2098
- return ;
2099
- }
2100
-
2101
- std::unordered_set<const VarNode*> vars;
2102
- for (int i = 0 ; i < static_cast <int >(store->indices .size ()); i++) {
2103
- if (const auto * v = store->indices [i].as <VarNode>()) {
2104
- vars.insert (v);
2105
- }
2106
- }
2107
- if (vars.size () == store->indices .size ()) {
2108
- bool use_other_var = false ;
2109
- for (int i = 0 ; i < static_cast <int >(load->indices .size ()); i++) {
2110
- if (tir::UsesVar (load->indices [i],
2111
- [&vars](const VarNode* var) { return !vars.count (var); })) {
2112
- use_other_var = true ;
2113
- break ;
2114
- }
2115
- }
2116
- if (!use_other_var) {
2117
- kind = relay::kInjective ;
2118
- kind_ = static_cast <int >(kind) > static_cast <int >(kind_) ? kind : kind_;
2119
- return ;
2120
- }
2121
- }
2122
- }
2123
- }
2124
- // test whether is reduce
2125
2131
for (IterVar it : op->iter_vars ) {
2126
2132
if (it->iter_type == kCommReduce ) {
2127
- kind = relay::kCommReduce ;
2128
- kind_ = static_cast <int >(kind)>static_cast <int >(kind_)?kind:kind_;
2133
+ kind_ = std::max (kind_, relay::kCommReduce );
2129
2134
return ;
2130
2135
}
2131
2136
}
2132
2137
2133
- kind_ = static_cast < int >(kind)> static_cast < int >(kind_)?kind:kind_ ;
2138
+ kind_ = relay:: kOpaque ;
2134
2139
}
2135
2140
2136
- Array<Array<PrimExpr>> indices_;
2141
+ Array<PrimExpr> store_indices_;
2142
+ Array<Array<PrimExpr>> load_indices_;
2137
2143
relay::OpPatternKind kind_ =relay::kElemWise ;
2138
2144
2139
2145
public:
0 commit comments