Skip to content
This repository was archived by the owner on May 22, 2023. It is now read-only.

Commit b2cf067

Browse files
committed
[USMP] Implement AssignPoolInfo pass
1 parent 2c15370 commit b2cf067

File tree

3 files changed

+485
-0
lines changed

3 files changed

+485
-0
lines changed

python/tvm/relax/transform/transform.py

+14
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,20 @@ def FuseTIR() -> tvm.ir.transform.Pass:
370370
return _ffi_api.FuseTIR()
371371

372372

373+
def AssignPoolInfo() -> tvm.ir.transform.Pass:
374+
"""Assign PoolInfo objects to Relax and TIR allocates depending on the function target
375+
376+
This pass would assign default PoolInfo objects to allocates that are not otherwise
377+
annotated, depending on pool info supplied for each target.
378+
379+
Returns
380+
-------
381+
ret : tvm.transform.Pass
382+
The registered pass for assigning pool infos.
383+
"""
384+
return _ffi_api.AssignPoolInfo()
385+
386+
373387
def _wrap_class_function_pass(pass_cls, pass_info):
374388
"""Wrap a python class as function pass."""
375389

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <tvm/target/target.h>
21+
#include <tvm/tir/stmt_functor.h>
22+
#include <tvm/tir/usmp/utils.h>
23+
24+
#include <string>
25+
#include <utility>
26+
27+
#include "tvm/relax/attrs/memory.h"
28+
#include "tvm/relax/expr_functor.h"
29+
30+
namespace tvm {
31+
32+
/*! \brief Assign PoolInfo objects to allocate that does not have any.
33+
* The schedulers have the oppurtunity to assign PoolInfo objects to
34+
* allocate nodes. However, each allocate node is expected to have
35+
* at least one PoolInfo node assigned to it. If it was not the case,
36+
* this Pass will assign all PoolInfo objects that the target could
37+
* access.*/
38+
39+
namespace tir {
40+
namespace usmp {
41+
42+
class TIRPoolInfoAssigner : public StmtExprMutator {
43+
public:
44+
explicit TIRPoolInfoAssigner(PrimFunc func, const Map<String, Array<PoolInfo>>& target_pool_infos,
45+
const Map<String, Array<PoolInfo>>& target_const_pool_infos)
46+
: func_(std::move(func)),
47+
target_pool_infos_(target_pool_infos),
48+
target_const_pool_infos_(target_const_pool_infos){};
49+
50+
Stmt operator()();
51+
52+
private:
53+
Stmt VisitStmt_(const AllocateNode* op) override;
54+
Stmt VisitStmt_(const AllocateConstNode* op) override;
55+
56+
PrimFunc func_;
57+
Map<String, Array<PoolInfo>> target_pool_infos_;
58+
Map<String, Array<PoolInfo>> target_const_pool_infos_;
59+
};
60+
61+
Stmt TIRPoolInfoAssigner::operator()() {
62+
return this->VisitStmt(func_->body);
63+
}
64+
65+
Stmt TIRPoolInfoAssigner::VisitStmt_(const AllocateNode* op) {
66+
Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value();
67+
ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_;
68+
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations);
69+
if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) {
70+
ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0)
71+
<< "Target " << PrettyPrint(tgt) << " not found among " << PrettyPrint(target_pool_infos_);
72+
annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]);
73+
}
74+
Stmt body = VisitStmt(op->body);
75+
auto allocate =
76+
Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body, annotations);
77+
return std::move(allocate);
78+
}
79+
80+
Stmt TIRPoolInfoAssigner::VisitStmt_(const AllocateConstNode* op) {
81+
if (!target_const_pool_infos_.size()) {
82+
return StmtExprMutator::VisitStmt_(op);
83+
}
84+
Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value();
85+
ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_;
86+
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(op->annotations);
87+
if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) {
88+
annotations.Set(kPoolCandidatesAllocateAttr, target_const_pool_infos_[tgt.value()->str()]);
89+
annotations.Set(kTargetPoolReadOnlyAccess, Integer(1));
90+
}
91+
Stmt body = VisitStmt(op->body);
92+
auto allocate_const =
93+
AllocateConst(op->buffer_var, op->dtype, op->extents, op->data, body, annotations);
94+
return std::move(allocate_const);
95+
}
96+
97+
} // namespace usmp
98+
} // namespace tir
99+
100+
namespace relax {
101+
namespace usmp {
102+
103+
class RelaxPoolInfoAssigner : public ExprMutator {
104+
public:
105+
explicit RelaxPoolInfoAssigner(Function func, const Map<String, Array<PoolInfo>>& target_pool_infos,
106+
const Map<String, Array<PoolInfo>>& target_const_pool_infos)
107+
: func_(std::move(func)),
108+
target_pool_infos_(target_pool_infos),
109+
target_const_pool_infos_(target_const_pool_infos){};
110+
111+
Expr operator()();
112+
113+
private:
114+
Expr VisitExpr_(const CallNode* op) override;
115+
116+
Function func_;
117+
Map<String, Array<PoolInfo>> target_pool_infos_;
118+
Map<String, Array<PoolInfo>> target_const_pool_infos_;
119+
};
120+
121+
Expr RelaxPoolInfoAssigner::operator()() {
122+
return this->VisitExpr(func_->body);
123+
}
124+
125+
Expr RelaxPoolInfoAssigner::VisitExpr_(const CallNode* call) {
126+
Expr expr = VisitExprPostOrder_(call);
127+
call = expr.as<CallNode>();
128+
129+
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
130+
if (call->op != alloc_tensor_op) {
131+
return GetRef<Call>(call);
132+
}
133+
Optional<Target> tgt = func_->GetAttr<Target>(tvm::attr::kTarget).value();
134+
ICHECK(tgt) << "The following Func does not have a target attr: \n" << func_;
135+
auto alloc_attrs = call->attrs.as<AllocTensorAttrs>();
136+
ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs";
137+
if (alloc_attrs->candidate_memory_pools.size() > 0) {
138+
return GetRef<Call>(call);
139+
}
140+
ICHECK(target_pool_infos_.count(tgt.value()->str()) > 0)
141+
<< "Target " << PrettyPrint(tgt) << " not found among " << PrettyPrint(target_pool_infos_);
142+
auto alloc_tensor_attr = make_object<AllocTensorAttrs>();
143+
alloc_tensor_attr->dtype = alloc_attrs->dtype;
144+
alloc_tensor_attr->runtime_device_index = alloc_attrs->runtime_device_index;
145+
alloc_tensor_attr->candidate_memory_pools = target_pool_infos_[tgt.value()->str()];
146+
auto allocate_call =
147+
Call(call->op, call->args, Attrs(alloc_tensor_attr), call->type_args, call->span);
148+
return std::move(allocate_call);
149+
}
150+
151+
} // namespace usmp
152+
} // namespace relax
153+
154+
class PoolInfoAssigner {
155+
public:
156+
explicit PoolInfoAssigner(const IRModule& module) {
157+
auto main_func =
158+
Downcast<relax::Function>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
159+
ICHECK(main_func.defined()) << "main function is not in the module";
160+
Optional<Target> target_host = main_func->GetAttr<Target>(tvm::attr::kTarget);
161+
ICHECK(target_host) << "main function does not have a target attr";
162+
WorkspaceMemoryPools workspace_pools =
163+
module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools)
164+
.value_or(WorkspaceMemoryPools({CreateDefaultWorkspaceMemoryPool(module)}));
165+
// make default ConstantPoolInfo if no constant and no workspace pool infos supplied
166+
ConstantMemoryPools constant_pools =
167+
module->GetAttr<ConstantMemoryPools>(tvm::attr::kConstantMemoryPools)
168+
.value_or(
169+
module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools).defined()
170+
? ConstantMemoryPools()
171+
: ConstantMemoryPools({CreateDefaultConstantMemoryPool(module)}));
172+
auto to_map = [](auto pool_infos) {
173+
Map<String, Array<PoolInfo>> pool_map;
174+
for (const PoolInfo& pool_info : pool_infos) {
175+
for (const auto& tgt : pool_info->targets) {
176+
if (pool_map.find(tgt->str()) == pool_map.end()) {
177+
pool_map.Set(tgt->str(), Array<PoolInfo>());
178+
}
179+
Array<PoolInfo> pool_info_arr = pool_map[tgt->str()];
180+
pool_info_arr.push_back(pool_info);
181+
pool_map.Set(tgt->str(), pool_info_arr);
182+
}
183+
}
184+
return pool_map;
185+
};
186+
187+
target_pool_infos_ = to_map(workspace_pools->pools);
188+
if (constant_pools.defined()) {
189+
target_const_pool_infos_ = to_map(constant_pools->pools);
190+
}
191+
mod_ = module->ShallowCopy();
192+
}
193+
194+
IRModule operator()();
195+
196+
private:
197+
IRModule mod_;
198+
Map<String, Array<PoolInfo>> target_pool_infos_;
199+
Map<String, Array<PoolInfo>> target_const_pool_infos_;
200+
WorkspacePoolInfo CreateDefaultWorkspaceMemoryPool(const IRModule& module);
201+
ConstantPoolInfo CreateDefaultConstantMemoryPool(const IRModule& module) {
202+
auto p = CreateDefaultWorkspaceMemoryPool(module);
203+
return ConstantPoolInfo(
204+
"global_const_workspace", {p->targets}, {},
205+
PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth,
206+
kUnknownWriteBandwidth, 0, 0, {p->target_burst_bytes}, Bool(true)));
207+
}
208+
};
209+
210+
WorkspacePoolInfo PoolInfoAssigner::CreateDefaultWorkspaceMemoryPool(const tvm::IRModule& module) {
211+
VLOG(1) << "Creating default memory pool for:" << std::endl << PrettyPrint(module);
212+
Map<Target, String> target_access;
213+
auto main_func =
214+
Downcast<tvm::BaseFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
215+
Target target_host = main_func->GetAttr<Target>(tvm::attr::kTarget).value();
216+
for (const auto& kv : module->functions) {
217+
BaseFunc func = kv.second;
218+
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
219+
target_access.Set(target.value_or(target_host), kTargetPoolReadWriteAccess);
220+
}
221+
Array<Target> targets;
222+
for (const auto& kv : target_access) {
223+
bool exist = false;
224+
// Exclude targets with the same string representation
225+
for (const auto& t : targets) {
226+
if (t->str() == kv.first->str()) {
227+
exist = true;
228+
}
229+
}
230+
if (!exist) {
231+
targets.push_back(kv.first);
232+
}
233+
}
234+
return WorkspacePoolInfo(
235+
"global_workspace", targets,
236+
PoolInfoProperties(kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth,
237+
kUnknownWriteBandwidth, 0, 0, {{target_host, 1}}, Bool(true)));
238+
}
239+
240+
IRModule PoolInfoAssigner::operator()() {
241+
for (const auto& kv : mod_->functions) {
242+
GlobalVar gv = kv.first;
243+
if (kv.second->IsInstance<relax::FunctionNode>()) {
244+
using RelaxPoolInfoAssigner = relax::usmp::RelaxPoolInfoAssigner;
245+
using Function = relax::Function;
246+
auto func = runtime::Downcast<Function>(kv.second);
247+
RelaxPoolInfoAssigner relax_pool_info_assigner = RelaxPoolInfoAssigner(func, target_pool_infos_, target_const_pool_infos_);
248+
relax::Expr body = relax_pool_info_assigner();
249+
Function new_relax_func = Function(func->params, body, func->ret_type, func->attrs, func->span);
250+
mod_->Update(gv, new_relax_func);
251+
} else if (kv.second->IsInstance<tir::PrimFuncNode>()) {
252+
using TIRPoolInfoAssigner = tir::usmp::TIRPoolInfoAssigner;
253+
using PrimFunc = tir::PrimFunc;
254+
auto func = Downcast<PrimFunc>(kv.second);
255+
TIRPoolInfoAssigner tir_pool_info_assigner = TIRPoolInfoAssigner(func, target_pool_infos_, target_const_pool_infos_);
256+
tir::Stmt body = tir_pool_info_assigner();
257+
PrimFunc new_prim_func = PrimFunc(func->params, body, func->ret_type, func->buffer_map,
258+
func->preflattened_buffer_map, func->attrs);
259+
mod_->Update(gv, new_prim_func);
260+
}
261+
}
262+
return mod_;
263+
}
264+
265+
namespace transform {
266+
267+
tvm::transform::Pass AssignPoolInfo() {
268+
auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
269+
return PoolInfoAssigner(m)();
270+
};
271+
return tvm::transform::CreateModulePass(pass_func, 0, "relax.usmp.AssignPoolInfo", {});
272+
}
273+
274+
TVM_REGISTER_GLOBAL("relax.transform.AssignPoolInfo").set_body_typed(AssignPoolInfo);
275+
276+
} // namespace transform
277+
} // namespace tvm

0 commit comments

Comments
 (0)