Commit f95b486

jinhongyii authored and MasterJH5574 committed
Fix Op Pattern Detection (tlc-pack#5)
1 parent d7c172c · commit f95b486

File tree

2 files changed: +96 -66 lines changed


src/tir/schedule/analysis/analysis.cc

+72 -66
@@ -2040,100 +2040,106 @@ bool CheckSameArray(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
   return true;
 }
 
+bool CheckElemwisePattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
+  if (indices_l.size() != indices_r.size()) {
+    return false;
+  }
+  int n = indices_l.size();
+  for (int i = 0; i < n; i++) {
+    if (!indices_l[i].same_as(indices_r[i])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CheckBroadcastPattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
+  if (indices_l.size() < indices_r.size()) {
+    return false;
+  }
+  int j=0;
+  for (int i = 0; i < static_cast<int>(indices_r.size()); i++) {
+    for (; j < static_cast<int>(indices_l.size()) && !indices_l[j].same_as
+         (indices_r[i]); j++);
+    if(j==static_cast<int>(indices_l.size())){
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CheckInjectivePattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
+  std::unordered_set<const VarNode*> vars;
+  for (int i = 0; i < static_cast<int>(indices_l.size()); i++) {
+    if (const auto* v = indices_l[i].as<VarNode>()) {
+      vars.insert(v);
+    } else {
+      return false;
+    }
+  }
+  for (int i = 0; i < static_cast<int>(indices_r.size()); i++) {
+    if (tir::UsesVar(indices_r[i],
+                     [&vars](const VarNode* var) { return !vars.count(var); })) {
+      return false;
+    }
+  }
+  return true;
+}
+
 class PatternKindAnalyzer: public StmtExprVisitor {
   void VisitStmt_(const BufferStoreNode* op) final {
-    indices_.push_back(op->indices);
+    store_indices_ = op->indices;
     StmtVisitor::VisitStmt_(op);
   }
   void VisitExpr_(const BufferLoadNode* op) final {
-    indices_.push_back(op->indices);
+    load_indices_.push_back(op->indices);
     ExprVisitor::VisitExpr_(op);
   }
 
-  void VisitExpr_(const CallNode* op) final {
-    kind_=relay::kOpaque;
-  }
-
   void VisitStmt_(const BlockNode* op) final {
     if (op->name_hint == "root") {
       StmtVisitor::VisitStmt(op->body);
       return;
     }
-
-    relay::OpPatternKind kind = relay::kOpaque;
 
-    //test whether is elemwise
-    indices_.clear();
+    load_indices_.clear();
+    store_indices_.clear();
     StmtVisitor::VisitStmt(op->body);
-    bool same_index = true;
-    for (int i = 1; i < static_cast<int>(indices_.size()); i++) {
-      if(!CheckSameArray(indices_[0],indices_[i])) {
-        same_index = false;
-        break;
+
+    relay::OpPatternKind index_pair_pattern = relay::kElemWise;
+    if (load_indices_.empty()) {
+      index_pair_pattern = relay::kBroadcast;
+    } else {
+      for (int i = 0; i < static_cast<int>(load_indices_.size()); i++) {
+        if (CheckElemwisePattern(store_indices_, load_indices_[i])) {
+          index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
+        } else if (CheckBroadcastPattern(store_indices_, load_indices_[i])) {
+          index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
+        } else if (CheckInjectivePattern(store_indices_, load_indices_[i])) {
+          index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
+        } else {
+          index_pair_pattern = relay::kOpaque;
+          break;
+        }
       }
     }
-    if (same_index) {
-      kind = relay::kElemWise;
-      kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
+    if (index_pair_pattern != relay::kOpaque) {
+      kind_ = std::max(kind_, index_pair_pattern);
       return;
     }
 
-    if (const auto* store = op->body.as<BufferStoreNode>()) {
-      if (const auto* load = store->value.as<BufferLoadNode>()) {
-        //test whether is broadcast
-        int j = 0;
-        bool all_var_axis = true;
-        for (int i = 0; i < static_cast<int>(load->indices.size()); i++) {
-          if (load->indices[i].as<VarNode>()) {
-            for (; j < static_cast<int>(store->indices.size()) && !store->indices[j].same_as
-                 (load->indices[i]); j++);
-          } else {
-            all_var_axis = false;
-            break;
-          }
-        }
-        if (all_var_axis && j != static_cast<int>(store->indices.size())) {
-          kind = relay::kBroadcast;
-          kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
-          return;
-        }
-
-        std::unordered_set<const VarNode*> vars;
-        for (int i = 0; i < static_cast<int>(store->indices.size()); i++) {
-          if (const auto* v = store->indices[i].as<VarNode>()) {
-            vars.insert(v);
-          }
-        }
-        if (vars.size() == store->indices.size()) {
-          bool use_other_var = false;
-          for (int i = 0; i < static_cast<int>(load->indices.size()); i++) {
-            if (tir::UsesVar(load->indices[i],
-                             [&vars](const VarNode* var) { return !vars.count(var); })) {
-              use_other_var = true;
-              break;
-            }
-          }
-          if (!use_other_var) {
-            kind = relay::kInjective;
-            kind_ = static_cast<int>(kind) > static_cast<int>(kind_) ? kind : kind_;
-            return;
-          }
-        }
-      }
-    }
-    //test whether is reduce
     for (IterVar it : op->iter_vars) {
       if (it->iter_type == kCommReduce) {
-        kind = relay::kCommReduce;
-        kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
+        kind_ = std::max(kind_, relay::kCommReduce);
        return;
       }
     }
 
-    kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
+    kind_ = relay::kOpaque;
   }
 
-  Array<Array<PrimExpr>> indices_;
+  Array<PrimExpr> store_indices_;
+  Array<Array<PrimExpr>> load_indices_;
   relay::OpPatternKind kind_ = relay::kElemWise;
 
  public:
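
For readers skimming the change: the three new helpers compare the index list of the block's store against the index list of each load, and the block keeps the loosest pattern any load needs (falling back to kOpaque if some load fits none of the checks). Below is a minimal, illustrative Python sketch of that classification order, using plain tuples of index-variable names instead of TIR PrimExprs. The names check_elemwise, check_broadcast, check_injective, and classify are invented for the sketch and are not part of the patch, and unlike the real CheckInjectivePattern it cannot express loads that are arbitrary expressions over the store variables. The numeric constants mirror relay::OpPatternKind (kElemWise = 0, kBroadcast = 1, kInjective = 2, kOpaque = 8).

# Illustrative sketch only: mirrors the classification order of
# PatternKindAnalyzer over tuples of index-variable names.
K_ELEMWISE, K_BROADCAST, K_INJECTIVE, K_OPAQUE = 0, 1, 2, 8  # relay enum values

def check_elemwise(store, load):
    # Load uses exactly the same indices, in the same order, as the store.
    return store == load

def check_broadcast(store, load):
    # Every load index appears among the store indices, preserving order
    # (the load may skip some store axes, e.g. bias[ax1] vs out[ax0, ax1]).
    it = iter(store)
    return len(store) >= len(load) and all(idx in it for idx in load)

def check_injective(store, load):
    # The load refers only to variables that index the store.
    return set(load) <= set(store)

def classify(store, loads):
    # A block with no loads (e.g. writing a constant) is treated as broadcast.
    if not loads:
        return K_BROADCAST
    pattern = K_ELEMWISE
    for load in loads:
        if check_elemwise(store, load):
            pattern = max(pattern, K_ELEMWISE)
        elif check_broadcast(store, load):
            pattern = max(pattern, K_BROADCAST)
        elif check_injective(store, load):
            pattern = max(pattern, K_INJECTIVE)
        else:
            return K_OPAQUE
    return pattern

# Bias add: T_add[ax0, ax1] = A[ax0, ax1] + B[ax1]  ->  broadcast (1)
assert classify(("ax0", "ax1"), [("ax0", "ax1"), ("ax1",)]) == K_BROADCAST
# Transpose-like access: out[i, j] = inp[j, i]  ->  injective (2)
assert classify(("i", "j"), [("j", "i")]) == K_INJECTIVE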

tests/python/relax/test_transform.py

+24
@@ -498,6 +498,30 @@ def foo(x: Tensor[(m, n), "float32"], w: Tensor[(n, k), "float32"]) -> Tensor:
     new_mod = relax.transform.AnnotateOpKind()(mod)
     assert new_mod["injective"].attrs["op_pattern"] == 2
 
+def test_annotate_op_kind_bias_add():
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func
+        def tir_bias_add(rxplaceholder_2: T.Buffer[(1, 1000), "float32"], rxplaceholder_3: T.Buffer[(1000,), "float32"], T_add_1: T.Buffer[(1, 1000), "float32"]) -> None:
+            # function attr dict
+            T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True, "op_pattern": 8})
+            # body
+            # with T.block("root")
+            for i0, i1 in T.grid(1, 1000):
+                with T.block("T_add"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])
+                    T.writes(T_add_1[ax0, ax1])
+                    T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]
+
+        @R.function
+        def foo(x: Tensor[(1, 1000), "float32"], y: Tensor[(1000, ), "float32"]) -> Tensor:
+            gv0 = R.call_tir(tir_bias_add, (x, y), (1, 1000), dtype="float32")
+            return gv0
+
+    mod = InputModule
+    new_mod = relax.transform.AnnotateOpKind()(mod)
+    assert new_mod["tir_bias_add"].attrs["op_pattern"] == 1
 
 def test_layout_rewrite():
     @tvm.script.ir_module
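
Why the assertion expects 1: in the T_add block above, the store is T_add_1[ax0, ax1] and the two loads are rxplaceholder_2[ax0, ax1] (same indices as the store, so elemwise) and rxplaceholder_3[ax1] (an ordered subset of the store indices, so broadcast). The loosest pattern wins, so the test expects AnnotateOpKind to replace the module's initial "op_pattern": 8 attribute with kBroadcast = 1. In terms of the simplified sketch shown after the analysis.cc diff (illustrative only):

classify(("ax0", "ax1"), [("ax0", "ax1"), ("ax1",)])  # == 1, i.e. kBroadcast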

0 commit comments