@@ -1917,100 +1917,106 @@ bool CheckSameArray(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
1917
1917
return true ;
1918
1918
}
1919
1919
1920
+ bool CheckElemwisePattern (const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
1921
+ if (indices_l.size () != indices_r.size ()) {
1922
+ return false ;
1923
+ }
1924
+ int n = indices_l.size ();
1925
+ for (int i = 0 ; i < n; i++) {
1926
+ if (!indices_l[i].same_as (indices_r[i])) {
1927
+ return false ;
1928
+ }
1929
+ }
1930
+ return true ;
1931
+ }
1932
+
1933
+ bool CheckBroadcastPattern (const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
1934
+ if (indices_l.size () < indices_r.size ()) {
1935
+ return false ;
1936
+ }
1937
+ int j=0 ;
1938
+ for (int i = 0 ; i < static_cast <int >(indices_r.size ()); i++) {
1939
+ for (; j < static_cast <int >(indices_l.size ()) && !indices_l[j].same_as
1940
+ (indices_r[i]); j++);
1941
+ if (j==static_cast <int >(indices_l.size ())){
1942
+ return false ;
1943
+ }
1944
+ }
1945
+ return true ;
1946
+ }
1947
+
1948
+ bool CheckInjectivePattern (const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
1949
+ std::unordered_set<const VarNode*> vars;
1950
+ for (int i = 0 ; i < static_cast <int >(indices_l.size ()); i++) {
1951
+ if (const auto * v = indices_l[i].as <VarNode>()) {
1952
+ vars.insert (v);
1953
+ } else {
1954
+ return false ;
1955
+ }
1956
+ }
1957
+ for (int i = 0 ; i < static_cast <int >(indices_r.size ()); i++) {
1958
+ if (tir::UsesVar (indices_r[i],
1959
+ [&vars](const VarNode* var) { return !vars.count (var); })) {
1960
+ return false ;
1961
+ }
1962
+ }
1963
+ return true ;
1964
+ }
1965
+
1920
1966
class PatternKindAnalyzer : public StmtExprVisitor {
1921
1967
void VisitStmt_ (const BufferStoreNode* op) final {
1922
- indices_. push_back ( op->indices ) ;
1968
+ store_indices_ = op->indices ;
1923
1969
StmtVisitor::VisitStmt_ (op);
1924
1970
}
1925
1971
void VisitExpr_ (const BufferLoadNode* op) final {
1926
- indices_ .push_back (op->indices );
1972
+ load_indices_ .push_back (op->indices );
1927
1973
ExprVisitor::VisitExpr_ (op);
1928
1974
}
1929
1975
1930
- void VisitExpr_ (const CallNode* op) final {
1931
- kind_=relay::kOpaque ;
1932
- }
1933
-
1934
1976
void VisitStmt_ (const BlockNode* op)final {
1935
1977
if (op->name_hint == " root" ) {
1936
1978
StmtVisitor::VisitStmt (op->body );
1937
1979
return ;
1938
1980
}
1939
-
1940
- relay::OpPatternKind kind = relay::kOpaque ;
1941
1981
1942
- // test whether is elemwise
1943
- indices_ .clear ();
1982
+ load_indices_. clear ();
1983
+ store_indices_ .clear ();
1944
1984
StmtVisitor::VisitStmt (op->body );
1945
- bool same_index = true ;
1946
- for (int i = 1 ; i < static_cast <int >(indices_.size ()); i++) {
1947
- if (!CheckSameArray (indices_[0 ],indices_[i])) {
1948
- same_index = false ;
1949
- break ;
1985
+
1986
+ relay::OpPatternKind index_pair_pattern = relay::kElemWise ;
1987
+ if (load_indices_.empty ()) {
1988
+ index_pair_pattern = relay::kBroadcast ;
1989
+ } else {
1990
+ for (int i = 0 ; i < static_cast <int >(load_indices_.size ()); i++) {
1991
+ if (CheckElemwisePattern (store_indices_, load_indices_[i])) {
1992
+ index_pair_pattern = std::max (index_pair_pattern, relay::kElemWise );
1993
+ } else if (CheckBroadcastPattern (store_indices_, load_indices_[i])) {
1994
+ index_pair_pattern = std::max (index_pair_pattern, relay::kBroadcast );
1995
+ } else if (CheckInjectivePattern (store_indices_, load_indices_[i])) {
1996
+ index_pair_pattern = std::max (index_pair_pattern, relay::kInjective );
1997
+ } else {
1998
+ index_pair_pattern = relay::kOpaque ;
1999
+ break ;
2000
+ }
1950
2001
}
1951
2002
}
1952
- if (same_index) {
1953
- kind = relay::kElemWise ;
1954
- kind_ = static_cast <int >(kind)>static_cast <int >(kind_)?kind:kind_;
2003
+ if (index_pair_pattern != relay::kOpaque ) {
2004
+ kind_ = std::max (kind_, index_pair_pattern);
1955
2005
return ;
1956
2006
}
1957
2007
1958
- if (const auto * store = op->body .as <BufferStoreNode>()) {
1959
- if (const auto * load = store->value .as <BufferLoadNode>()) {
1960
- // test whether is broadcast
1961
- int j = 0 ;
1962
- bool all_var_axis = true ;
1963
- for (int i = 0 ; i < static_cast <int >(load->indices .size ()); i++) {
1964
- if (load->indices [i].as <VarNode>()) {
1965
- for (; j < static_cast <int >(store->indices .size ()) && !store->indices [j].same_as
1966
- (load->indices [i]); j++);
1967
- } else {
1968
- all_var_axis = false ;
1969
- break ;
1970
- }
1971
- }
1972
- if (all_var_axis && j != static_cast <int >(store->indices .size ())) {
1973
- kind = relay::kBroadcast ;
1974
- kind_ = static_cast <int >(kind)>static_cast <int >(kind_)?kind:kind_;
1975
- return ;
1976
- }
1977
-
1978
- std::unordered_set<const VarNode*> vars;
1979
- for (int i = 0 ; i < static_cast <int >(store->indices .size ()); i++) {
1980
- if (const auto * v = store->indices [i].as <VarNode>()) {
1981
- vars.insert (v);
1982
- }
1983
- }
1984
- if (vars.size () == store->indices .size ()) {
1985
- bool use_other_var = false ;
1986
- for (int i = 0 ; i < static_cast <int >(load->indices .size ()); i++) {
1987
- if (tir::UsesVar (load->indices [i],
1988
- [&vars](const VarNode* var) { return !vars.count (var); })) {
1989
- use_other_var = true ;
1990
- break ;
1991
- }
1992
- }
1993
- if (!use_other_var) {
1994
- kind = relay::kInjective ;
1995
- kind_ = static_cast <int >(kind) > static_cast <int >(kind_) ? kind : kind_;
1996
- return ;
1997
- }
1998
- }
1999
- }
2000
- }
2001
- // test whether is reduce
2002
2008
for (IterVar it : op->iter_vars ) {
2003
2009
if (it->iter_type == kCommReduce ) {
2004
- kind = relay::kCommReduce ;
2005
- kind_ = static_cast <int >(kind)>static_cast <int >(kind_)?kind:kind_;
2010
+ kind_ = std::max (kind_, relay::kCommReduce );
2006
2011
return ;
2007
2012
}
2008
2013
}
2009
2014
2010
- kind_ = static_cast < int >(kind)> static_cast < int >(kind_)?kind:kind_ ;
2015
+ kind_ = relay:: kOpaque ;
2011
2016
}
2012
2017
2013
- Array<Array<PrimExpr>> indices_;
2018
+ Array<PrimExpr> store_indices_;
2019
+ Array<Array<PrimExpr>> load_indices_;
2014
2020
relay::OpPatternKind kind_ =relay::kElemWise ;
2015
2021
2016
2022
public:
0 commit comments