似乎有必要介绍下NNVM
NNVM不是一个深度学习库,而是一个去中心化的、轻量的深度学习系统构建辅助模块。
不同于很多深度学习系统提供端到端的解决方案,NNVM尝试为自定义一个深度学习系统提供可能和关键性模块。它提供一种通用的方式进行计算图优化,如内存节约和设备调度,而无关于运算接口定义和运算执行。它更像是神经网络、计算图生成和优化的中间表示层。
示例:
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2);
NNVM_REGISTER_OP(conv2d)
.describe("take 2d convolution of input")
.set_num_inputs(2);
其中,
#define NNVM_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
#define NNVM_REGISTER_VAR_DEF(OpName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
在返过来看示例代码,NNVM_REGISTER_OP定义一个变量(变量名称通过__COUNTER__宏防止重名),并给它赋值为nnvm::Op对象的引用,后面可以接着使用.运算符,同理,describe和set_num_inputs也是返回该对象的引用,可级联使用.运算符,注意最后必须跟“;”
从 __REGISTER_OR_GET__ 可以看出这个如果原先注册了某一操作,会保存在内部的字典中,因此可以一个操作在不同地方进行注册。
在nnvm::Op中,除了describe、set_num_inputs还有一些的常用的调用:
std::function<void(NodeAttrs* attrs)> attr_parser = nullptr; // 函数封装, 属性解析器
inline Op& describe(const std::string& descr);
inline Op& add_argument(const std::string &name,
const std::string &type,
const std::string &description); //添加参数
inline Op& add_arguments(const std::vector<ParamFieldInfo> &args);
inline Op& set_num_inputs(uint32_t n); // NOLINT(*),设置输入数目
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
inline Op& set_num_outputs(uint32_t n); // NOLINT(*),设置输出数目
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*),设置属性解析器
template<typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 10); // 添加属性,ValueType为某一函数封装
Op& add_alias(const std::string& alias); // 添加别名
Op& include(const std::string& group_name); // 从一个操作组中继承所有属性
注册完操作后,可以在前段代码里构造计算图,如:
import nnvm.symbol as nn
# symbolic variable
x = nn.Variable('x')
y = nn.Variable('y')
w = nn.Variable('w')
z = nn.conv2d(nn.add(x, y), w, filter_size=(2,2), name='conv1')
构造完成的图(Graph)通过Pass进行转换和优化,如符号求导、内存规划、类型/大小自动推断等。另,图的执行并不在NNVM中,需额外实现。