Neural Network Loader
Background
The previous posts covered two classic feedforward neural networks, LeNet and VGG. Their concrete implementations contain a lot of duplicated code, and as more networks and their variants are introduced, the code becomes hard to maintain and not very readable. Take the VGG family as an example:
| Net | BatchNorm | Non-BatchNorm |
| --- | --- | --- |
| A | 1 | 1 |
| A-LRN | 1 | 1 |
| B | 1 | 1 |
| C | 1 | 1 |
| D | 1 | 1 |
| E | 1 | 1 |
That is 12 combinations in total. Written in the previous style:
VGGNet::VGGNet(int num_classes)
: C1 (register_module("C1", Conv2d(Conv2dOptions( 3, 64, 3).padding(1))))
, C1B (register_module("C1B", BatchNorm2d(BatchNorm2dOptions(64))))
, C3 (register_module("C3", Conv2d(Conv2dOptions( 64, 64, 3).padding(1))))
, C3B (register_module("C3B", BatchNorm2d(BatchNorm2dOptions(64))))
, C6 (register_module("C6", Conv2d(Conv2dOptions( 64, 128, 3).padding(1))))
, C6B (register_module("C6B", BatchNorm2d(BatchNorm2dOptions(128))))
, C8 (register_module("C8", Conv2d(Conv2dOptions(128, 128, 3).padding(1))))
, C8B (register_module("C8B", BatchNorm2d(BatchNorm2dOptions(128))))
, C11 (register_module("C11", Conv2d(Conv2dOptions(128, 256, 3).padding(1))))
, C11B(register_module("C11B",BatchNorm2d(BatchNorm2dOptions(256))))
, C13 (register_module("C13", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
, C13B(register_module("C13B",BatchNorm2d(BatchNorm2dOptions(256))))
, C15 (register_module("C15", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
, C15B(register_module("C15B",BatchNorm2d(BatchNorm2dOptions(256))))
, C18 (register_module("C18", Conv2d(Conv2dOptions(256, 512, 3).padding(1))))
, C18B(register_module("C18B",BatchNorm2d(BatchNorm2dOptions(512))))
, C20 (register_module("C20", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C20B(register_module("C20B",BatchNorm2d(BatchNorm2dOptions(512))))
, C22 (register_module("C22", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C22B(register_module("C22B",BatchNorm2d(BatchNorm2dOptions(512))))
, C25 (register_module("C25", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C25B(register_module("C25B",BatchNorm2d(BatchNorm2dOptions(512))))
, C27 (register_module("C27", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C27B(register_module("C27B",BatchNorm2d(BatchNorm2dOptions(512))))
, C29 (register_module("C29", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C29B(register_module("C29B",BatchNorm2d(BatchNorm2dOptions(512))))
, FC32(register_module("FC32",Linear(512 * 7 * 7, 4096)))
, FC35(register_module("FC35",Linear(4096, 4096)))
, FC38(register_module("FC38",Linear(4096, num_classes))){}
On top of that, every network needs a forward method that maps the input tensor to the output tensor:
// block#1
x = F::max_pool2d(F::relu(C3B(C3(F::relu(C1B(C1(x)),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)),
F::MaxPool2dFuncOptions(2));
// block#2
x = F::max_pool2d(F::relu(C8B(C8(F::relu(C6B(C6(x)),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)),
F::MaxPool2dFuncOptions(2));
// block#3
x = F::max_pool2d(F::relu(C15B(C15(F::relu(C13B(C13(F::relu(C11B(C11(x)),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)),
F::MaxPool2dFuncOptions(2));
// block#4
x = F::max_pool2d(F::relu(C22B(C22(F::relu(C20B(C20(F::relu(C18B(C18(x)),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)),
F::MaxPool2dFuncOptions(2));
// block#5
x = F::max_pool2d(F::relu(C29B(C29(F::relu(C27B(C27(F::relu(C25B(C25(x)),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)))),
F::ReLUFuncOptions(true)),
F::MaxPool2dFuncOptions(2));
Doing this for every variant means 12 classes for VGG alone. The code volume is large and full of regular, repetitive patterns. Following the DRY ("don't repeat yourself") principle, it is time to factor this out and get rid of the repetitive work!
Designing the network configuration file
The subtitle may sound a bit grandiose, but for now the design only needs to cover feedforward networks. Judging from LeNet and VGG, such a network consists of two parts:
- Module registration: the basic network topology, where all weight layers are defined and registered
- Tensor-processing topology: the pipeline that transforms the input tensor
The configuration file is therefore designed around these two parts. It also doubles as a database of neural networks: all popular networks can be defined in it and loaded by name.
<?xml version="1.0" encoding="utf-8"?>
<nns>
<nn name="VGGD_NoBatchNorm">
<modules>
<!-- Block#1 -->
<module name="C1" type="conv2d" in_channels="3" out_channels="64" kernel_size="3" padding="1" />
<module name="C3" type="conv2d" in_channels="64" out_channels="64" kernel_size="3" padding="1" />
<!-- Block#2 -->
<module name="C6" type="conv2d" in_channels="64" out_channels="128" kernel_size="3" padding="1" />
<module name="C8" type="conv2d" in_channels="128" out_channels="128" kernel_size="3" padding="1" />
<!-- Block#3 -->
<module name="C11" type="conv2d" in_channels="128" out_channels="256" kernel_size="3" padding="1" />
<module name="C13" type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
<module name="C15" type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
<!-- Block#4 -->
<module name="C18" type="conv2d" in_channels="256" out_channels="512" kernel_size="3" padding="1" />
<module name="C20" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C22" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<!-- Block#5 -->
<module name="C25" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C27" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C29" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<!-- FC -->
<module name="FC32" type="linear" in_features="25088" out_features="4096" />
<module name="FC35" type="linear" in_features="4096" out_features="4096" />
<module name="FC38" type="linear" in_features="4096" out_features="1000" />
</modules>
<forward>
<f module="C1" />
<f functional="relu" inplace="true" />
<f module="C3" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C6" />
<f functional="relu" inplace="true" />
<f module="C8" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C11" />
<f functional="relu" inplace="true" />
<f module="C13" />
<f functional="relu" inplace="true" />
<f module="C15" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C18" />
<f functional="relu" inplace="true" />
<f module="C20" />
<f functional="relu" inplace="true" />
<f module="C22" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C25" />
<f functional="relu" inplace="true" />
<f module="C27" />
<f functional="relu" inplace="true" />
<f module="C29" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f view="flat" />
<f module="FC32" />
<f functional="relu" inplace="true" />
<f functional="dropout" inplace="true" p="0.5" />
<f module="FC35" />
<f functional="relu" inplace="true" />
<f functional="dropout" inplace="true" p="0.5" />
<f module="FC38" />
</forward>
</nn>
<nn name="VGGD_BatchNorm">
<modules>
<!-- Block#1 -->
<module name="C1" type="conv2d" in_channels="3" out_channels="64" kernel_size="3" padding="1" />
<module name="C1B" type="batchnorm2d" num_features="64" />
<module name="C3" type="conv2d" in_channels="64" out_channels="64" kernel_size="3" padding="1" />
<module name="C3B" type="batchnorm2d" num_features="64" />
<!-- Block#2 -->
<module name="C6" type="conv2d" in_channels="64" out_channels="128" kernel_size="3" padding="1" />
<module name="C6B" type="batchnorm2d" num_features="128" />
<module name="C8" type="conv2d" in_channels="128" out_channels="128" kernel_size="3" padding="1" />
<module name="C8B" type="batchnorm2d" num_features="128" />
<!-- Block#3 -->
<module name="C11" type="conv2d" in_channels="128" out_channels="256" kernel_size="3" padding="1" />
<module name="C11B" type="batchnorm2d" num_features="256" />
<module name="C13" type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
<module name="C13B" type="batchnorm2d" num_features="256" />
<module name="C15" type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
<module name="C15B" type="batchnorm2d" num_features="256" />
<!-- Block#4 -->
<module name="C18" type="conv2d" in_channels="256" out_channels="512" kernel_size="3" padding="1" />
<module name="C18B" type="batchnorm2d" num_features="512" />
<module name="C20" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C20B" type="batchnorm2d" num_features="512" />
<module name="C22" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C22B" type="batchnorm2d" num_features="512" />
<!-- Block#5 -->
<module name="C25" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C25B" type="batchnorm2d" num_features="512" />
<module name="C27" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C27B" type="batchnorm2d" num_features="512" />
<module name="C29" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
<module name="C29B" type="batchnorm2d" num_features="512" />
<!-- FC -->
<module name="FC32" type="linear" in_features="25088" out_features="4096" />
<module name="FC35" type="linear" in_features="4096" out_features="4096" />
<module name="FC38" type="linear" in_features="4096" out_features="1000" />
</modules>
<forward>
<f module="C1" />
<f module="C1B" />
<f functional="relu" inplace="true" />
<f module="C3" />
<f module="C3B" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C6" />
<f module="C6B" />
<f functional="relu" inplace="true" />
<f module="C8" />
<f module="C8B" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C11" />
<f module="C11B" />
<f functional="relu" inplace="true" />
<f module="C13" />
<f module="C13B" />
<f functional="relu" inplace="true" />
<f module="C15" />
<f module="C15B" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C18" />
<f module="C18B" />
<f functional="relu" inplace="true" />
<f module="C20" />
<f module="C20B" />
<f functional="relu" inplace="true" />
<f module="C22" />
<f module="C22B" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f module="C25" />
<f module="C25B" />
<f functional="relu" inplace="true" />
<f module="C27" />
<f module="C27B" />
<f functional="relu" inplace="true" />
<f module="C29" />
<f module="C29B" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
<f view="flat" />
<f module="FC32" />
<f functional="relu" inplace="true" />
<f functional="dropout" inplace="true" p="0.5" />
<f module="FC35" />
<f functional="relu" inplace="true" />
<f functional="dropout" inplace="true" p="0.5" />
<f module="FC38" />
</forward>
</nn>
</nns>
The example above defines two variants of VGG-D, one with batch normalization and one without:
<nn name="VGGD_NoBatchNorm">...</nn>
<nn name="VGGD_BatchNorm">...</nn>
Each network declares which modules need to be registered:
<modules>...</modules>
Each module entry specifies the module's name and type, used for registration with libtorch, plus type-specific parameters used to initialize the module. For example, the entry below defines a 2D convolution module named C1 with 3 input channels, 64 output channels, a 3x3 kernel, and padding 1:
<module name="C1" type="conv2d" in_channels="3" out_channels="64" kernel_size="3" padding="1" />
Finally comes the tensor-processing pipeline. The fragment below sends the tensor through convolution layer C1, then batch-norm layer C1B, then an in-place relu functional (no module is registered for it, which saves memory), then C3 and C3B followed by another in-place relu, and finally a max-pooling functional for sub-sampling:
<forward>
<f module="C1" />
<f module="C1B" />
<f functional="relu" inplace="true" />
<f module="C3" />
<f module="C3B" />
<f functional="relu" inplace="true" />
<f functional="max_pool2d" kernel_size="2" />
...
</forward>
Loading the network automatically from the configuration file
tinyxml2 is used as the XML parser: it is memory-efficient, and since it consists of a single .h/.cpp pair it is extremely easy to add to a project. Loading breaks down into the following steps:
- Load the configuration file
int iRet = 0;
if (xmlDoc.LoadFile("nnconfig.xml") != tinyxml2::XML_SUCCESS)
{
printf("Failed to load 'nnconfig.xml'.\n");
return -1;
}
- Find the configuration of the requested network by name
if (xmlDoc.RootElement() == NULL ||
xmlDoc.RootElement()->NoChildren() ||
XP_STRICMP(xmlDoc.RootElement()->Name(), "nns") != 0)
{
printf("It is an invalid 'nnconfig.xml' file.\n");
return -1;
}
auto child = xmlDoc.RootElement()->FirstChildElement("nn");
while (child != NULL)
{
if (XP_STRICMP(child->Attribute("name"), szNNName) == 0)
break;
child = child->NextSiblingElement("nn");
}
if (child == NULL)
{
printf("Failed to find the neural network '%s'.\n", szNNName);
return -1;
}
// Load the modules one by one
auto Elementmodules = child->FirstChildElement("modules");
if (Elementmodules == NULL || Elementmodules->NoChildren())
{
printf("No modules to be loaded.\n");
return -1;
}
auto ElementModule = Elementmodules->FirstChildElement("module");
while (ElementModule != NULL)
{
if (LoadModule(ElementModule) != 0)
{
iRet = -1;
printf("Failed to load the module '%s'.\n",
ElementModule->Attribute("name") ? ElementModule->Attribute("name") : "(null)");
goto done;
}
ElementModule = ElementModule->NextSiblingElement("module");
}
// load the forward list
auto ElementForward = child->FirstChildElement("forward");
if (ElementForward != NULL && !ElementForward->NoChildren())
{
auto f = ElementForward->FirstChildElement("f");
while (f != NULL)
{
forward_list.push_back(f);
f = f->NextSiblingElement();
}
}
nn_model_name = szNNName;
- Load and register each module
int iRet = 0;
if (moduleElement == NULL)
return -1;
const char* szModuleType = moduleElement->Attribute("type");
if (szModuleType == NULL)
{
printf("Please specify the module type.\n");
return -1;
}
const char* szModuleName = moduleElement->Attribute("name");
if (szModuleName == NULL)
{
printf("Please specify the module name.\n");
return -1;
}
if (XP_STRICMP(szModuleType, "conv2d") == 0)
{
// extract the attributes of conv2d
int64_t in_channels = moduleElement->Int64Attribute("in_channels", -1LL); assert(in_channels > 0);
int64_t out_channels = moduleElement->Int64Attribute("out_channels", -1LL); assert(out_channels > 0);
int64_t kernel_size = moduleElement->Int64Attribute("kernel_size", -1LL); assert(kernel_size > 0);
int64_t padding = moduleElement->Int64Attribute("padding", 1LL);
std::shared_ptr<torch::nn::Module> spConv2D =
std::make_shared<torch::nn::Conv2dImpl>(
torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size).padding(padding));
nn_modules[szModuleName] = spConv2D;
nn_module_types[szModuleName] = szModuleType;
register_module(szModuleName, spConv2D);
}
else if (XP_STRICMP(szModuleType, "linear") == 0)
{
int64_t in_features = moduleElement->Int64Attribute("in_features", -1LL); assert(in_features > 0);
int64_t out_features = moduleElement->Int64Attribute("out_features", -1LL); assert(out_features > 0);
std::shared_ptr<torch::nn::Module> spLinear =
std::make_shared<torch::nn::LinearImpl>(in_features, out_features);
nn_modules[szModuleName] = spLinear;
nn_module_types[szModuleName] = szModuleType;
register_module(szModuleName, spLinear);
}
else if (XP_STRICMP(szModuleType, "batchnorm2d") == 0)
{
int64_t num_features = moduleElement->Int64Attribute("num_features", -1); assert(num_features > 0);
std::shared_ptr<torch::nn::Module> spBatchNorm2D =
std::make_shared<torch::nn::BatchNorm2dImpl>(num_features);
nn_modules[szModuleName] = spBatchNorm2D;
nn_module_types[szModuleName] = szModuleType;
register_module(szModuleName, spBatchNorm2D);
}
else
{
printf("Failed to load the module with unsupported type: '%s'.\n", szModuleType);
iRet = -1;
}
return iRet;
- Implement the generic forward pass
for (auto& f : forward_list)
{
const char* szModuleName = f->Attribute("module");
if (szModuleName != NULL)
{
auto m = nn_modules.find(szModuleName);
if (m != nn_modules.end())
{
std::string& module_type = nn_module_types[szModuleName];
if (module_type == "conv2d")
{
auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(m->second);
input = spConv2d->forward(input);
continue;
}
else if (module_type == "linear")
{
auto spLinear = std::dynamic_pointer_cast<torch::nn::LinearImpl>(m->second);
input = spLinear->forward(input);
continue;
}
else if (module_type == "batchnorm2d")
{
auto spBatchNorm2d = std::dynamic_pointer_cast<torch::nn::BatchNorm2dImpl>(m->second);
input = spBatchNorm2d->forward(input);
continue;
}
}
}
const char* szFunctional = f->Attribute("functional");
if (szFunctional != NULL)
{
if (XP_STRICMP(szFunctional, "relu") == 0)
{
bool inplace = f->BoolAttribute("inplace", false);
input = F::relu(input, F::ReLUFuncOptions(inplace));
continue;
}
else if (XP_STRICMP(szFunctional, "max_pool2d") == 0)
{
int64_t kernel_size = f->Int64Attribute("kernel_size", 2);
input = F::max_pool2d(input, F::MaxPool2dFuncOptions(kernel_size));
continue;
}
else if (XP_STRICMP(szFunctional, "dropout") == 0)
{
double p = f->DoubleAttribute("p", 0.5);
bool inplace = f->BoolAttribute("inplace", false);
input = F::dropout(input, F::DropoutFuncOptions().p(p).inplace(inplace));
continue;
}
}
const char* szView = f->Attribute("view");
if (szView != NULL)
{
if (XP_STRICMP(szView, "flat") == 0)
{
input = input.view({ input.size(0), -1 });
continue;
}
}
}
Testing
Of course, the implementation above is only a skeleton: the set of supported network modules and tensor-processing functionals is still small, but it is easy to extend and will grow as needed. The code is on GitHub; see BaseNNet.h/cpp.
With this simple loader in place, it can be tested with the following code:
BaseNNet bnnt;
bnnt.LoadNet("VGGD_BatchNorm");
bnnt.Print();
The printout then confirms that a VGG-D network with BatchNorm has been loaded.
Summary
From now on, a neural network can simply be defined in nnconfig.xml and loaded from there, leaving the remaining effort for organizing datasets and implementing train, verify, and test.