Skip to content

Commit 46685cb

Browse files
jinhongyiiMasterJH5574
authored andcommitted
Fix Op Pattern Detection (tlc-pack#5)
1 parent ab6c35a commit 46685cb

File tree

2 files changed

+96
-66
lines changed

2 files changed

+96
-66
lines changed

src/tir/schedule/analysis/analysis.cc

+72-66
Original file line numberDiff line numberDiff line change
@@ -1917,100 +1917,106 @@ bool CheckSameArray(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
19171917
return true;
19181918
}
19191919

1920+
bool CheckElemwisePattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
1921+
if (indices_l.size() != indices_r.size()) {
1922+
return false;
1923+
}
1924+
int n = indices_l.size();
1925+
for (int i = 0; i < n; i++) {
1926+
if (!indices_l[i].same_as(indices_r[i])) {
1927+
return false;
1928+
}
1929+
}
1930+
return true;
1931+
}
1932+
1933+
bool CheckBroadcastPattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
1934+
if (indices_l.size() < indices_r.size()) {
1935+
return false;
1936+
}
1937+
int j=0;
1938+
for (int i = 0; i < static_cast<int>(indices_r.size()); i++) {
1939+
for (; j < static_cast<int>(indices_l.size()) && !indices_l[j].same_as
1940+
(indices_r[i]); j++);
1941+
if(j==static_cast<int>(indices_l.size())){
1942+
return false;
1943+
}
1944+
}
1945+
return true;
1946+
}
1947+
1948+
bool CheckInjectivePattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r){
1949+
std::unordered_set<const VarNode*> vars;
1950+
for (int i = 0; i < static_cast<int>(indices_l.size()); i++) {
1951+
if (const auto* v = indices_l[i].as<VarNode>()) {
1952+
vars.insert(v);
1953+
} else {
1954+
return false;
1955+
}
1956+
}
1957+
for (int i = 0; i < static_cast<int>(indices_r.size()); i++) {
1958+
if (tir::UsesVar(indices_r[i],
1959+
[&vars](const VarNode* var) { return !vars.count(var); })) {
1960+
return false;
1961+
}
1962+
}
1963+
return true;
1964+
}
1965+
19201966
class PatternKindAnalyzer: public StmtExprVisitor {
19211967
void VisitStmt_(const BufferStoreNode* op) final {
1922-
indices_.push_back(op->indices);
1968+
store_indices_ = op->indices;
19231969
StmtVisitor::VisitStmt_(op);
19241970
}
19251971
void VisitExpr_(const BufferLoadNode* op) final {
1926-
indices_.push_back(op->indices);
1972+
load_indices_.push_back(op->indices);
19271973
ExprVisitor::VisitExpr_(op);
19281974
}
19291975

1930-
void VisitExpr_(const CallNode* op) final {
1931-
kind_=relay::kOpaque;
1932-
}
1933-
19341976
void VisitStmt_(const BlockNode* op)final {
19351977
if (op->name_hint == "root") {
19361978
StmtVisitor::VisitStmt(op->body);
19371979
return;
19381980
}
1939-
1940-
relay::OpPatternKind kind = relay::kOpaque;
19411981

1942-
//test whether is elemwise
1943-
indices_.clear();
1982+
load_indices_.clear();
1983+
store_indices_.clear();
19441984
StmtVisitor::VisitStmt(op->body);
1945-
bool same_index = true;
1946-
for (int i = 1; i < static_cast<int>(indices_.size()); i++) {
1947-
if(!CheckSameArray(indices_[0],indices_[i])) {
1948-
same_index = false;
1949-
break;
1985+
1986+
relay::OpPatternKind index_pair_pattern = relay::kElemWise;
1987+
if (load_indices_.empty()) {
1988+
index_pair_pattern = relay::kBroadcast;
1989+
} else {
1990+
for (int i = 0; i < static_cast<int>(load_indices_.size()); i++) {
1991+
if (CheckElemwisePattern(store_indices_, load_indices_[i])) {
1992+
index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
1993+
} else if (CheckBroadcastPattern(store_indices_, load_indices_[i])) {
1994+
index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
1995+
} else if (CheckInjectivePattern(store_indices_, load_indices_[i])) {
1996+
index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
1997+
} else {
1998+
index_pair_pattern = relay::kOpaque;
1999+
break;
2000+
}
19502001
}
19512002
}
1952-
if (same_index) {
1953-
kind = relay::kElemWise;
1954-
kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
2003+
if (index_pair_pattern != relay::kOpaque) {
2004+
kind_ = std::max(kind_, index_pair_pattern);
19552005
return;
19562006
}
19572007

1958-
if (const auto* store = op->body.as<BufferStoreNode>()) {
1959-
if (const auto* load = store->value.as<BufferLoadNode>()) {
1960-
//test whether is broadcast
1961-
int j = 0;
1962-
bool all_var_axis = true;
1963-
for (int i = 0; i < static_cast<int>(load->indices.size()); i++) {
1964-
if (load->indices[i].as<VarNode>()) {
1965-
for (; j < static_cast<int>(store->indices.size()) && !store->indices[j].same_as
1966-
(load->indices[i]); j++);
1967-
} else {
1968-
all_var_axis = false;
1969-
break;
1970-
}
1971-
}
1972-
if (all_var_axis && j != static_cast<int>(store->indices.size())) {
1973-
kind = relay::kBroadcast;
1974-
kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
1975-
return;
1976-
}
1977-
1978-
std::unordered_set<const VarNode*> vars;
1979-
for (int i = 0; i < static_cast<int>(store->indices.size()); i++) {
1980-
if (const auto* v = store->indices[i].as<VarNode>()) {
1981-
vars.insert(v);
1982-
}
1983-
}
1984-
if (vars.size() == store->indices.size()) {
1985-
bool use_other_var = false;
1986-
for (int i = 0; i < static_cast<int>(load->indices.size()); i++) {
1987-
if (tir::UsesVar(load->indices[i],
1988-
[&vars](const VarNode* var) { return !vars.count(var); })) {
1989-
use_other_var = true;
1990-
break;
1991-
}
1992-
}
1993-
if (!use_other_var) {
1994-
kind = relay::kInjective;
1995-
kind_ = static_cast<int>(kind) > static_cast<int>(kind_) ? kind : kind_;
1996-
return;
1997-
}
1998-
}
1999-
}
2000-
}
2001-
//test whether is reduce
20022008
for (IterVar it : op->iter_vars) {
20032009
if (it->iter_type == kCommReduce) {
2004-
kind = relay::kCommReduce;
2005-
kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
2010+
kind_ = std::max(kind_, relay::kCommReduce);
20062011
return;
20072012
}
20082013
}
20092014

2010-
kind_ = static_cast<int>(kind)>static_cast<int>(kind_)?kind:kind_;
2015+
kind_ = relay::kOpaque;
20112016
}
20122017

2013-
Array<Array<PrimExpr>> indices_;
2018+
Array<PrimExpr> store_indices_;
2019+
Array<Array<PrimExpr>> load_indices_;
20142020
relay::OpPatternKind kind_ =relay::kElemWise;
20152021

20162022
public:

tests/python/relax/test_transform.py

+24
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,30 @@ def foo(x: Tensor[(m, n), "float32"], w: Tensor[(n, k), "float32"]) -> Tensor:
457457
new_mod =relax.transform.AnnotateOpKind()(mod)
458458
assert new_mod["injective"].attrs["op_pattern"] == 2
459459

460+
def test_annotate_op_kind_bias_add():
461+
@tvm.script.ir_module
462+
class InputModule:
463+
@T.prim_func
464+
def tir_bias_add(rxplaceholder_2: T.Buffer[(1, 1000), "float32"], rxplaceholder_3: T.Buffer[(1000,), "float32"], T_add_1: T.Buffer[(1, 1000), "float32"]) -> None:
465+
# function attr dict
466+
T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True, "op_pattern": 8})
467+
# body
468+
# with T.block("root")
469+
for i0, i1 in T.grid(1, 1000):
470+
with T.block("T_add"):
471+
ax0, ax1 = T.axis.remap("SS", [i0, i1])
472+
T.reads(rxplaceholder_2[ax0, ax1], rxplaceholder_3[ax1])
473+
T.writes(T_add_1[ax0, ax1])
474+
T_add_1[ax0, ax1] = rxplaceholder_2[ax0, ax1] + rxplaceholder_3[ax1]
475+
476+
@R.function
477+
def foo(x: Tensor[(1, 1000), "float32"], y: Tensor[(1000, ), "float32"]) -> Tensor:
478+
gv0 = R.call_tir(tir_bias_add, (x, y), (1, 1000), dtype="float32")
479+
return gv0
480+
481+
mod = InputModule
482+
new_mod =relax.transform.AnnotateOpKind()(mod)
483+
assert new_mod["tir_bias_add"].attrs["op_pattern"] == 1
460484

461485
def test_layout_rewrite():
462486
@tvm.script.ir_module

0 commit comments

Comments
 (0)