在执行GraphExecutorCodegen::Codegen时,一开始就调用GraphPlanMemory分配内存,这个函数的实现:
StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); }
这里实例化了一个StorageAllocator对象,并调用它的Plan方法。在StorageAllocator::Plan的一开始有:
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
1 创建token表
StorageAllocaInit::GetInitTokenMap的实现
std::unordered_map<const ExprNode*, std::vector<StorageToken*>> GetInitTokenMap(
const Function& func) {
this->Run(func);
return std::move(token_map_);
}
StorageAllocaInit继承自StorageAllocaBaseVisitor类,StorageAllocaBaseVisitor::Run方法:
void Run(const Function& func) { VisitExpr(func); }
因为 StorageAllocaBaseVisitor继承自DeviceAwareExprVisitor, DeviceAwareExprVisitor继承自ExprVisitor。这个StorageAllocaBaseVisitor调用VisitExpr,就是调用ExprVisitor::VisitExpr。ExprVisitor::VisitExpr会根据传入的数据类型调用对应的VisitExpr_。而DeviceAwareExprVisitor和StorageAllocaBaseVisitor共同重载了各种类型表达式的遍历函数VisitExpr_。这样在Run中调用VisitExpr的时候,最终会走到各重载的VisitExpr_中,执行相应的操作:
详细的流程和机制可以参考【TVM源码学习笔记】3.1.1 VisitExpr流程分析
从上图可以看到,对函数定义、调用以及let这种复合表达式的遍历处理都在DeviceAwareExprVisitor中,而比较基础的语法单元,如变量,全局变量,常量,元组,元组访问以及if语句等的遍历处理定义在StorageAllocaBaseVisitor中。
这里还有几个接口我们了解下继承和重载关系:
我们先分析下函数节点的VisitExpr_的实现:
// TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing
// mutator on top of the generic Functor.
void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) {
if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
// No tracking inside primitive functions.
DeviceAwareVisitExpr_(function_node);
} else {
// Function parameters come into scope.
for (auto param : function_node->params) {
PushBoundVar(param, param->virtual_device());
}
// Entering scope of function body.
PushVirtualDevice(function_node->virtual_device());
EnterFunctionBody();
DeviceAwareVisitExpr_(function_node);
// Leaving scope of function body.
ExitFunctionBody();
PopVirtualDevice();
// Function parameters go out of scope.
for (size_t i = 0; i < function_node->params.size(); ++i) {
PopBoundVar(function_node->params[i]);
}
}
}
处理函数定义节点的时候分两种情况:
1. 函数有attr::kPrimitive属性(即名字为Primitive的属性)且非零,调用DeviceAwareVisitExpr_处理函数定义;
2. 否则,将函数参数压栈,然后调用调用DeviceAwareVisitExpr_处理函数定义,处理完后出栈。
这里DeviceAwareVisitExpr_的参数是FunctionNode,根据前面类图我们可以知道,这里是调用的StorageAllocaBaseVisitor的DeviceAwareVisitExpr_:
void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
if (function_nesting() > 1) {
// do not recurse into sub functions.
return;
}
if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
// No storage needed for primitive functions.
return;
}
for (const auto& param : func_node->params) {
CreateToken(param.get(), /*can_realloc=*/false);
}
// Process the function body, and make sure all result tokens are considered 'alive'.
for (StorageToken* tok : GetToken(func_node->body)) {
tok->ref_counter += 1;
}
}
这里对函数的处理仅仅只是对函数参数创建标识符节点。CreateToken定义在StorageAllocaBaseVisitor,调用CreateTokenOnDevice方法。该方法在StorageAllocaInit和StorageAllocator中分别实现。这里是StorageAllocaInit实例调进来的,所以在该类中找方法的实现:
void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
bool can_realloc) override {
ICHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens;
for (const auto& ttype : FlattenTupleType(op->checked_type())) {
auto* token = arena_->make<StorageToken>();
token->ttype = ttype;
token->virtual_device = virtual_device;
tokens.push_back(token);
}
token_map_[op] = tokens;
}
op->checked_type()是每个算子自己定义的类型推理接口,详见【TVM源码学习笔记】Relay算子实现流程
这里对每个函数参数创建一个标识符结构体StorageToken,加入token_map_表;已经在token_map_表中的不会重复添加。
这里只是创建了StorageToken来创建token表,并没有为标记符对应的实际数据(例如tensor)分配空间。
在创建token表的过程中,StorageAllocaBaseVisitor中会对函数定义,函数调用,常量和tuple相关语法单元中的token加入token表,而对变量,全局变量,op, if等语法单元不做处理。这是为什么呢?
2 StorageAllocator::Run
回到StorageAllocator::Plan中,在创建token表后,执行了StorageAllocator::Run:
// Run storage allocation for a function.
StaticMemoryPlan Plan(const Function& func) {
VLOG_CONTEXT << "StorageAllocator";
VLOG(1) << "planning:" << std::endl << PrettyPrint(func);
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
this->Run(func);
因为StorageAllocator并没有实现Run方法,所以这里Run和前面StorageAllocaInit::GetInitTokenMap一样,调用的是StorageAllocaBaseVisitor::Run,并且参数都一样。这样两者对各语法单元的遍历接口VisitExpr和VisitExpr_也都一样。那么两次Run的调用差别在哪里呢?最重要的差别在CreateTokenOnDevice()上:
StorageAllocaInit通过Run接口遍历模型的所有表达式,对各种语法单元的最后处理都落在CreateTokenOnDevice()里面,这里只是将token分配一个StorageToken内存,加入token表。
同样的流程,同样的处理,StorageAllocator最后调用到CreateTokenOnDevice()的时候,会为每个token分配实际的数据内存。这里内存分配涉及到内存管理。当前不做分析,后面会专门深入讨论。
我们继续看内存分配函数Plan:
// Run storage allocation for a function.
StaticMemoryPlan Plan(const Function& func) {
VLOG_CONTEXT << "StorageAllocator";
VLOG(1) << "planning:" << std::endl << PrettyPrint(func);
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
this->Run(func);
// The value of smap contains two integer arrays where the first array
// contains the planned storage ids and the second holds the device types.
// smap的值包含两个整数数组,第一个数组是分配的空间id,第二个是(内存?)设备类型。
Map<Expr, backend::StorageInfo> smap;
int num_annotated_nodes = 0;
int num_nodes = 0;
//遍历token表
for (const auto& kv : token_map_) {
//三个vector分别记录storage_id, 表达式执行的设备,占用内存大小
std::vector<int64_t> storage_ids;
storage_ids.reserve(kv.second.size());
std::vector<VirtualDevice> virtual_devices;
virtual_devices.reserve(kv.second.size());
std::vector<int64_t> sid_sizes_byte;
sid_sizes_byte.reserve(kv.second.size());
//遍历表达式中的token
for (StorageToken* tok : kv.second) {
VLOG(1) << "token: " << tok->ToString();
if (tok->is_valid()) {
num_annotated_nodes++;
}
num_nodes++;
//记录token的storage_id,设备和内存大小
storage_ids.push_back(tok->storage_id);
virtual_devices.push_back(tok->virtual_device);
sid_sizes_byte.push_back(GetMemorySize(tok));
}
//为每个表达式都实例化一个backend::StorageInfo,加入smap表
auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(virtual_devices),
std::move(sid_sizes_byte));
//kv.first是表达式类型(constant, let, tuple)
smap.Set(GetRef<Expr>(kv.first), storage_info);
}
// Either all or none of the nodes should be annotated.
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
<< "expressions are assigned with virtual device types. Either all "
"or none of the expressions are expected to be annotated.";
}
return backend::StaticMemoryPlan(smap);
}
整个流程:
1. 遍历模型的各表达式,创建token表;
2. 遍历模型各表达式,为token分配内存;
3. 将分配的内存编号,内存所在的(device)位置,内存大小打包返回