code:https://github.com/OAID/Tengine
version: 88b4b7a2
图片,代码都来自以上项目。
1. 简介
…
2. 正题
- 创建一个空的graph,
ir_graph_t* ir_graph = create_ir_graph((struct context*)context);
- 找到模型序列化器
struct serializer* loader = find_serializer_via_name(model_format);
,然后调用load_model
方法 - load_graph,会初始化
graph
的tensor_list
,node_list
,输入输出的node
,subgraph_list
if (load_graph_tensors(tm2_s, graph, priv) < 0) // load tensor_list
goto error;
if (load_graph_nodes(tm2_s, graph, priv) < 0) // initialize node_list
goto error;
if (set_graph_io_nodes(tm2_s, graph, priv) < 0) // record the graph's input/output nodes
goto error;
if (load_graph_sub_info(tm2_s, graph, priv) < 0) // load subgraph info
goto error;
load_graph_tensors
中会遍历预训练的模型,创建tensor
,包括type,data,shape等属性,并添加到graph.tensor_list
中。load_graph_nodes
中会遍历预训练模型中的nodes,创建node
,设置node的index,input_num, output_num, input_tensors, output_tensors, op
等属性,并添加到graph.node_list
source/operator/prototype
目录下所有op文件是用于node的op初始化,op->param_mem参数空间申请。
#include "graph/tensor.h"
#include "graph/node.h"
#include "graph/graph.h"
#include "module/module.h"
#include "utility/sys_port.h"
/* Shape inference for this op: the output tensor simply takes the
 * input tensor's shape unchanged. Returns 0 on success. */
static int infer_shape(struct node* node)
{
    struct graph* graph = node->graph;
    struct tensor* in = get_ir_graph_tensor(graph, node->input_tensors[0]);
    struct tensor* out = get_ir_graph_tensor(graph, node->output_tensors[0]);

    /* Propagate the input dimensions onto the output tensor. */
    set_ir_tensor_shape(out, in->dims, in->dim_num);

    return 0;
}
/* Initialize the op object: this op carries no parameter block, and its
 * output shape is computed by infer_shape(). Returns 0 on success. */
static int init_op(struct op* op)
{
    op->infer_shape = infer_shape; /* shape hook registered first */
    op->same_shape = 0;
    op->param_size = 0;
    op->param_mem = NULL; /* no per-op parameters to allocate */

    return 0;
}
static void release_op(struct op* op)
{
sys_free(op->param_mem);
}
int register_absval_op()
{
struct method m;
m.version = 1;
m.init = init_op; // 函数指针
m.release = release_op; // 函数指针
return register_op(OP_ABSVAL, OP_ABSVAL_NAME, &m); // OP_ABSVAL:enum, OP_ABSVAL_NAME:字符串
}
/* Remove version 1 of the AbsVal operator from the registry.
 * Returns whatever unregister_op() returns. */
int unregister_absval_op()
{
    int ret = unregister_op(OP_ABSVAL, 1);
    return ret;
}
source/serializer/tmfile/op
文件夹下的所有op文件,是对node->op->param_mem进行赋值,上一步是申请内存初始化。如果用c++的话,可以把这两步的操作放到一个类中,框架更加清晰
/* Map a tmfile operator id to the internal operator id.
 * BatchNorm always maps to OP_BATCHNORM regardless of the input id. */
static int batchnorm_op_map(int op)
{
    (void)op; /* the mapping is unconditional; silence -Wunused-parameter */
    return OP_BATCHNORM;
}
static int tm2_load_batchnorm(struct graph* ir_graph, struct node* ir_node, const TM2_Node* tm_node, const TM2_Operator* tm_op)
{
struct batchnorm_param* batchnorm_param = ( struct batchnorm_param* )ir_node->op.param_mem;
const struct tm2_priv* tm2_priv = (struct tm2_priv*)ir_graph->serializer_privacy;
const char* mem_base = tm2_priv->base;
const TM2_BatchNormParam* tm_param = ( TM2_BatchNormParam* )(mem_base + tm_op->offset_t_param);
batchnorm_param->rescale_factor = tm_param->rescale_factor; // op 的参数
batchnorm_param->eps = tm_param->eps;
batchnorm_param->caffe_flavor = tm_param->caffe_flavor;
return 0;
}