libtorch study notes (9) - Implementing your own neural network loader

Neural Network Loader

Background

The previous posts covered two classic feed-forward networks, LeNet and VGG. Their implementations contain quite a bit of duplicated code, and as more networks and more variants of each network get introduced, the code becomes hard to maintain and not very readable. Take the VGG family as an example:

Net      BatchNorm  Non-BatchNorm
A        1          1
A-LRN    1          1
B        1          1
C        1          1
D        1          1
E        1          1

That makes 12 combinations in total. Written the way we did before, each one needs a constructor like this:

VGGNet::VGGNet(int num_classes)
    : C1  (register_module("C1",  Conv2d(Conv2dOptions(  3,  64, 3).padding(1))))
    , C1B (register_module("C1B", BatchNorm2d(BatchNorm2dOptions(64))))
    , C3  (register_module("C3",  Conv2d(Conv2dOptions( 64,  64, 3).padding(1))))
    , C3B (register_module("C3B", BatchNorm2d(BatchNorm2dOptions(64))))
    , C6  (register_module("C6",  Conv2d(Conv2dOptions( 64, 128, 3).padding(1))))
    , C6B (register_module("C6B", BatchNorm2d(BatchNorm2dOptions(128))))
    , C8  (register_module("C8",  Conv2d(Conv2dOptions(128, 128, 3).padding(1))))
    , C8B (register_module("C8B", BatchNorm2d(BatchNorm2dOptions(128))))
    , C11 (register_module("C11", Conv2d(Conv2dOptions(128, 256, 3).padding(1))))
    , C11B(register_module("C11B",BatchNorm2d(BatchNorm2dOptions(256))))
    , C13 (register_module("C13", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
    , C13B(register_module("C13B",BatchNorm2d(BatchNorm2dOptions(256))))
    , C15 (register_module("C15", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
    , C15B(register_module("C15B",BatchNorm2d(BatchNorm2dOptions(256))))
    , C18 (register_module("C18", Conv2d(Conv2dOptions(256, 512, 3).padding(1))))
    , C18B(register_module("C18B",BatchNorm2d(BatchNorm2dOptions(512))))
    , C20 (register_module("C20", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
    , C20B(register_module("C20B",BatchNorm2d(BatchNorm2dOptions(512))))
    , C22 (register_module("C22", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
    , C22B(register_module("C22B",BatchNorm2d(BatchNorm2dOptions(512))))
    , C25 (register_module("C25", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
    , C25B(register_module("C25B",BatchNorm2d(BatchNorm2dOptions(512))))
    , C27 (register_module("C27", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
    , C27B(register_module("C27B",BatchNorm2d(BatchNorm2dOptions(512))))
    , C29 (register_module("C29", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
    , C29B(register_module("C29B",BatchNorm2d(BatchNorm2dOptions(512))))
    , FC32(register_module("FC32",Linear(512 * 7 * 7, 4096)))
    , FC35(register_module("FC35",Linear(4096, 4096)))
    , FC38(register_module("FC38",Linear(4096, num_classes))){}

On top of that, each network needs a forward method that turns the input tensor into the output tensor:

// block#1
x = F::max_pool2d(F::relu(C3B(C3(F::relu(C1B(C1(x)), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)), 
    F::MaxPool2dFuncOptions(2));

// block#2
x = F::max_pool2d(F::relu(C8B(C8(F::relu(C6B(C6(x)), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)), 
    F::MaxPool2dFuncOptions(2));

// block#3
x = F::max_pool2d(F::relu(C15B(C15(F::relu(C13B(C13(F::relu(C11B(C11(x)), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)), 
    F::MaxPool2dFuncOptions(2));

// block#4
x = F::max_pool2d(F::relu(C22B(C22(F::relu(C20B(C20(F::relu(C18B(C18(x)), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)), 
    F::MaxPool2dFuncOptions(2));

// block#5
x = F::max_pool2d(F::relu(C29B(C29(F::relu(C27B(C27(F::relu(C25B(C25(x)), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)))), 
    F::ReLUFuncOptions(true)), 
    F::MaxPool2dFuncOptions(2));

Doing this for every network means twelve classes for VGG alone. That is a lot of code, and most of it is regular, repetitive code. Following the "Don't Repeat Yourself" principle, it is time to factor the code out and get rid of the duplicated labor!

Designing the network-loading configuration file

The subheading is a bit of an overstatement, but for now the design only has to cover feed-forward networks. Judging from LeNet and VGG, such a network consists of two parts:

  1. Module registration: the basic network topology is built here, and all the weight layers are defined and registered
  2. The tensor-processing topology: the flow that the input tensor is pushed through

The configuration file is therefore designed around these two parts. It also serves as a database of neural networks: all the popular networks can be defined in it and then loaded by name.

<?xml version="1.0" encoding="utf-8"?>

<nns>
  <nn name="VGGD_NoBatchNorm">
    <modules>
      <!-- Block#1 -->
      <module name="C1"  type="conv2d" in_channels="3"   out_channels="64"  kernel_size="3" padding="1" />
      <module name="C3"  type="conv2d" in_channels="64"  out_channels="64"  kernel_size="3" padding="1" />
      <!-- Block#2 -->
      <module name="C6"  type="conv2d" in_channels="64"  out_channels="128" kernel_size="3" padding="1" />
      <module name="C8"  type="conv2d" in_channels="128" out_channels="128" kernel_size="3" padding="1" />
      <!-- Block#3 -->
      <module name="C11" type="conv2d" in_channels="128" out_channels="256" kernel_size="3" padding="1" />
      <module name="C13" type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
      <module name="C15" type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
      <!-- Block#4 -->
      <module name="C18" type="conv2d" in_channels="256" out_channels="512" kernel_size="3" padding="1" />
      <module name="C20" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C22" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <!-- Block#5 -->
      <module name="C25" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C27" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C29" type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <!-- FC -->
      <module name="FC32" type="linear" in_features="25088" out_features="4096" />
      <module name="FC35" type="linear" in_features="4096"  out_features="4096" />
      <module name="FC38" type="linear" in_features="4096"  out_features="1000" />
    </modules>
    <forward>
      <f module="C1" />
      <f functional="relu" inplace="true" />
      <f module="C3" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C6" />
      <f functional="relu" inplace="true" />
      <f module="C8" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C11" />
      <f functional="relu" inplace="true" />
      <f module="C13" />
      <f functional="relu" inplace="true" />
      <f module="C15" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C18" />
      <f functional="relu" inplace="true" />
      <f module="C20" />
      <f functional="relu" inplace="true" />
      <f module="C22" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C25" />
      <f functional="relu" inplace="true" />
      <f module="C27" />
      <f functional="relu" inplace="true" />
      <f module="C29" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f view="flat" />
      <f module="FC32" />
      <f functional="relu" inplace="true" />
      <f functional="dropout" inplace="true" p="0.5" />
      <f module="FC35" />
      <f functional="relu" inplace="true" />
      <f functional="dropout" inplace="true" p="0.5" />
      <f module="FC38" />
    </forward>
  </nn>
  <nn name="VGGD_BatchNorm">
    <modules>
      <!-- Block#1 -->
      <module name="C1"   type="conv2d" in_channels="3"   out_channels="64"  kernel_size="3" padding="1" />
      <module name="C1B"  type="batchnorm2d" num_features="64" />
      <module name="C3"   type="conv2d" in_channels="64"  out_channels="64"  kernel_size="3" padding="1" />
      <module name="C3B"  type="batchnorm2d" num_features="64" />
      <!-- Block#2 -->
      <module name="C6"   type="conv2d" in_channels="64"  out_channels="128" kernel_size="3" padding="1" />
      <module name="C6B"  type="batchnorm2d" num_features="128" />
      <module name="C8"   type="conv2d" in_channels="128" out_channels="128" kernel_size="3" padding="1" />
      <module name="C8B"  type="batchnorm2d" num_features="128" />
      <!-- Block#3 -->
      <module name="C11"  type="conv2d" in_channels="128" out_channels="256" kernel_size="3" padding="1" />
      <module name="C11B" type="batchnorm2d" num_features="256" />
      <module name="C13"  type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
      <module name="C13B" type="batchnorm2d" num_features="256" />
      <module name="C15"  type="conv2d" in_channels="256" out_channels="256" kernel_size="3" padding="1" />
      <module name="C15B" type="batchnorm2d" num_features="256" />
      <!-- Block#4 -->
      <module name="C18"  type="conv2d" in_channels="256" out_channels="512" kernel_size="3" padding="1" />
      <module name="C18B" type="batchnorm2d" num_features="512" />
      <module name="C20"  type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C20B" type="batchnorm2d" num_features="512" />
      <module name="C22"  type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C22B" type="batchnorm2d" num_features="512" />
      <!-- Block#5 -->
      <module name="C25"  type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C25B" type="batchnorm2d" num_features="512" />
      <module name="C27"  type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C27B" type="batchnorm2d" num_features="512" />
      <module name="C29"  type="conv2d" in_channels="512" out_channels="512" kernel_size="3" padding="1" />
      <module name="C29B" type="batchnorm2d" num_features="512" />
      <!-- FC -->
      <module name="FC32" type="linear" in_features="25088" out_features="4096" />
      <module name="FC35" type="linear" in_features="4096"  out_features="4096" />
      <module name="FC38" type="linear" in_features="4096"  out_features="1000" />
    </modules>
    <forward>
      <f module="C1" />
      <f module="C1B" />
      <f functional="relu" inplace="true" />
      <f module="C3" />
      <f module="C3B" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C6" />
      <f module="C6B" />
      <f functional="relu" inplace="true" />
      <f module="C8" />
      <f module="C8B" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C11" />
      <f module="C11B" />
      <f functional="relu" inplace="true" />
      <f module="C13" />
      <f module="C13B" />
      <f functional="relu" inplace="true" />
      <f module="C15" />
      <f module="C15B" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C18" />
      <f module="C18B" />
      <f functional="relu" inplace="true" />
      <f module="C20" />
      <f module="C20B" />
      <f functional="relu" inplace="true" />
      <f module="C22" />
      <f module="C22B" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f module="C25" />
      <f module="C25B" />
      <f functional="relu" inplace="true" />
      <f module="C27" />
      <f module="C27B" />
      <f functional="relu" inplace="true" />
      <f module="C29" />
      <f module="C29B" />
      <f functional="relu" inplace="true" />
      <f functional="max_pool2d" kernel_size="2" />
      <f view="flat" />
      <f module="FC32" />
      <f functional="relu" inplace="true" />
      <f functional="dropout" inplace="true" p="0.5" />
      <f module="FC35" />
      <f functional="relu" inplace="true" />
      <f functional="dropout" inplace="true" p="0.5" />
      <f module="FC38" />
    </forward>
  </nn> 
</nns>

The example above defines two networks, VGG-D with batch normalization and VGG-D without it:

<nn name="VGGD_NoBatchNorm">...</nn>
<nn name="VGGD_BatchNorm">...</nn>

Each network declares which modules need to be registered:

<modules>...</modules>

Each module definition gives the module's name and type, which are used to register it with libtorch, plus type-specific parameters used to initialize it. For example, the element below defines a 2D convolution module named C1 whose input channels, output channels, and kernel size are 3, 64, and 3x3 respectively:

<module name="C1" type="conv2d" in_channels="3" out_channels="64" kernel_size="3" padding="1" />

The last part defines the tensor-processing flow. For example, the snippet below runs the tensor through convolution layer C1, then batchnorm layer C1B, then an in-place functional relu (no module is registered for it, which saves some memory), then the C3/C3B convolution and batchnorm pair followed by another in-place relu, and finally a max-pooling function for sub-sampling:

<forward>
  <f module="C1" />
  <f module="C1B" />
  <f functional="relu" inplace="true" />
  <f module="C3" />
  <f module="C3B" />
  <f functional="relu" inplace="true" />
  <f functional="max_pool2d" kernel_size="2" />
  ...
</forward>
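
This replays exactly block #1 of the hand-written forward shown earlier:

x = F::max_pool2d(F::relu(C3B(C3(F::relu(C1B(C1(x)),
    F::ReLUFuncOptions(true)))),
    F::ReLUFuncOptions(true)),
    F::MaxPool2dFuncOptions(2));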

Loading the network automatically from the configuration file

The XML parsing is done with tinyxml2: it is memory-efficient and ships as a single .h/.cpp pair, so it is extremely easy to drop into a project. The loading logic lives in a class called BaseNNet and breaks down into a few steps.
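
Here is a minimal sketch of that class, inferred from the snippets that follow (the actual declaration is in BaseNNet.h in the repository):

// A minimal sketch of the loader class, inferred from the snippets below;
// see BaseNNet.h/.cpp for the actual declaration.
class BaseNNet : public torch::nn::Module
{
public:
    int LoadNet(const char* szNNName);           // find the <nn> node and register its modules
    torch::Tensor forward(torch::Tensor input);  // replay the <forward> list on the input tensor
    void Print();                                // print the loaded network structure

protected:
    int LoadModule(tinyxml2::XMLElement* moduleElement);

    tinyxml2::XMLDocument xmlDoc;                // keeps the parsed configuration alive
    std::string nn_model_name;
    // module name -> module instance, and module name -> type string ("conv2d", ...)
    std::map<std::string, std::shared_ptr<torch::nn::Module>> nn_modules;
    std::map<std::string, std::string> nn_module_types;
    // the <f> children of <forward>, in document order
    std::vector<tinyxml2::XMLElement*> forward_list;
};

The steps are: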

  • Load the configuration file
int iRet = 0;
if (xmlDoc.LoadFile("nnconfig.xml") != tinyxml2::XML_SUCCESS)
{
	printf("Failed to load 'nnconfig.xml'.\n");
	return -1;
}
  • Find the target network's configuration by name
if (xmlDoc.RootElement() == NULL || 
    xmlDoc.RootElement()->NoChildren() || 
    XP_STRICMP(xmlDoc.RootElement()->Name(), "nns") != 0)
{
    printf("It is an invalid 'nnconfig.xml' file.\n");
    return -1;
}

auto child = xmlDoc.RootElement()->FirstChildElement("nn");
while (child != NULL)
{
    if (child->Attribute("name") != NULL &&
        XP_STRICMP(child->Attribute("name"), szNNName) == 0)
        break;
    child = child->NextSiblingElement("nn");
}

if (child == NULL)
{
    printf("Failed to find the neutral network '%s'.\n", szNNName);
    return -1;
}

// Load the modules one by one
auto Elementmodules = child->FirstChildElement("modules");
if (Elementmodules == NULL || Elementmodules->NoChildren())
{
    printf("No modules to be loaded.\n");
    return -1;
}

auto ElementModule = Elementmodules->FirstChildElement("module");
while (ElementModule != NULL)
{
    if (LoadModule(ElementModule) != 0)
    {
        iRet = -1;
        printf("Failed to load the module '%s'.\n", 
            ElementModule->Attribute("name", "(null)"));
        goto done;
    }
    ElementModule = ElementModule->NextSiblingElement("module");
}

// load the forward list
auto ElementForward = child->FirstChildElement("forward");
if (ElementForward != NULL && !ElementForward->NoChildren())
{
    auto f = ElementForward->FirstChildElement("f");
    while (f != NULL)
    {
        forward_list.push_back(f);
        f = f->NextSiblingElement("f");
    }
}

nn_model_name = szNNName;
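
Note that forward_list stores raw XMLElement pointers owned by xmlDoc, so the XML document has to stay alive for as long as the network is used.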
  • Load and register the modules
int iRet = 0;
if (moduleElement == NULL)
    return -1;

const char* szModuleType = moduleElement->Attribute("type");
if (szModuleType == NULL)
{
    printf("Please specify the module type.\n");
    return -1;
}

const char* szModuleName = moduleElement->Attribute("name");
if (szModuleName == NULL)
{
    printf("Please specify the module name.\n");
    return -1;
}

if (XP_STRICMP(szModuleType, "conv2d") == 0)
{
    // extract the attributes of conv2d
    int64_t in_channels = moduleElement->Int64Attribute("in_channels", -1LL); assert(in_channels > 0);
    int64_t out_channels = moduleElement->Int64Attribute("out_channels", -1LL); assert(out_channels > 0);
    int64_t kernel_size = moduleElement->Int64Attribute("kernel_size", -1LL); assert(kernel_size > 0);
    int64_t padding = moduleElement->Int64Attribute("padding", 1LL);

    std::shared_ptr<torch::nn::Module> spConv2D = 
        std::make_shared<torch::nn::Conv2dImpl>(
            torch::nn::Conv2dOptions(in_channels, out_channels, kernel_size).padding(padding));
    nn_modules[szModuleName] = spConv2D;
    nn_module_types[szModuleName] = szModuleType;
    register_module(szModuleName, spConv2D);
}
else if (XP_STRICMP(szModuleType, "linear") == 0)
{
    int64_t in_features = moduleElement->Int64Attribute("in_features", -1LL); assert(in_features > 0);
    int64_t out_features = moduleElement->Int64Attribute("out_features", -1LL); assert(out_features > 0);

    std::shared_ptr<torch::nn::Module> spLinear =
        std::make_shared<torch::nn::LinearImpl>(in_features, out_features);
    nn_modules[szModuleName] = spLinear;
    nn_module_types[szModuleName] = szModuleType;
    register_module(szModuleName, spLinear);
}
else if (XP_STRICMP(szModuleType, "batchnorm2d") == 0)
{
    int64_t num_features = moduleElement->Int64Attribute("num_features", -1); assert(num_features > 0);

    std::shared_ptr<torch::nn::Module> spBatchNorm2D =
        std::make_shared<torch::nn::BatchNorm2dImpl>(num_features);
    nn_modules[szModuleName] = spBatchNorm2D;
    nn_module_types[szModuleName] = szModuleType;
    register_module(szModuleName, spBatchNorm2D);
}
else
{
    printf("Failed to load the module with unsupported type: '%s'.\n", szModuleType);
    iRet = -1;
}

return iRet;
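
Other module types slot into the same if/else ladder. As a hypothetical example (not in the current code), an "avgpool2d" module type could be handled like this, with a matching branch added to the forward dispatch below:

// Hypothetical extension sketch: an "avgpool2d" module type would follow
// the same pattern as the branches above.
else if (XP_STRICMP(szModuleType, "avgpool2d") == 0)
{
    int64_t kernel_size = moduleElement->Int64Attribute("kernel_size", -1LL); assert(kernel_size > 0);

    std::shared_ptr<torch::nn::Module> spAvgPool2D =
        std::make_shared<torch::nn::AvgPool2dImpl>(
            torch::nn::AvgPool2dOptions(kernel_size));
    nn_modules[szModuleName] = spAvgPool2D;
    nn_module_types[szModuleName] = szModuleType;
    register_module(szModuleName, spAvgPool2D);
}
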
  • Implement the generic tensor processing
for (auto& f : forward_list)
{
    const char* szModuleName = f->Attribute("module");
    if (szModuleName != NULL)
    {
        auto m = nn_modules.find(szModuleName);
        if (m != nn_modules.end())
        {
            std::string& module_type = nn_module_types[szModuleName];
            if (module_type == "conv2d")
            {
                auto spConv2d = std::dynamic_pointer_cast<torch::nn::Conv2dImpl>(m->second);
                input = spConv2d->forward(input);
                continue;
            }
            else if (module_type == "linear")
            {
                auto spLinear = std::dynamic_pointer_cast<torch::nn::LinearImpl>(m->second);
                input = spLinear->forward(input);
                continue;
            }
            else if (module_type == "batchnorm2d")
            {
                auto spBatchNorm2d = std::dynamic_pointer_cast<torch::nn::BatchNorm2dImpl>(m->second);
                input = spBatchNorm2d->forward(input);
                continue;
            }
        }
    }

    const char* szFunctional = f->Attribute("functional");
    if (szFunctional != NULL)
    {
        if (XP_STRICMP(szFunctional, "relu") == 0)
        {
            bool inplace = f->BoolAttribute("inplace", false);
            input = F::relu(input, F::ReLUFuncOptions(inplace));
            continue;
        }
        else if (XP_STRICMP(szFunctional, "max_pool2d") == 0)
        {
            int64_t kernel_size = f->Int64Attribute("kernel_size", 2);
            input = F::max_pool2d(input, F::MaxPool2dFuncOptions(kernel_size));
            continue;
        }
        else if (XP_STRICMP(szFunctional, "dropout") == 0)
        {
            double p = f->DoubleAttribute("p", 0.5);
            bool inplace = f->BoolAttribute("inplace", false);
            input = F::dropout(input, F::DropoutFuncOptions().p(p).inplace(inplace));
            continue;
        }
    }

    const char* szView = f->Attribute("view");
    if (szView != NULL)
    {
        if (XP_STRICMP(szView, "flat"))
        {
            input = input.view({ input.size(0), -1 });
            continue;
        }
    }
}
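
Assuming the dispatch loop above is the body of BaseNNet::forward, a minimal wrapper would look like this (a sketch, not the verbatim implementation):

// Sketch: the dispatch loop sits inside forward(), and the tensor that
// falls out of the loop after the last <f> step is the network's output.
torch::Tensor BaseNNet::forward(torch::Tensor input)
{
    for (auto& f : forward_list)
    {
        // ... the module/functional/view dispatch shown above ...
    }
    return input;
}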

Testing

Of course, the implementation above is only skeletal: it supports only a handful of module types and tensor-processing functions so far, but it is easy to extend and will grow as needed. The code is on GitHub; see BaseNNet.h/cpp.
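
For instance (a hypothetical extension, not yet in the code), supporting a log_softmax functional would only take a new attribute form in the XML, such as <f functional="log_softmax" dim="1" />, plus one more branch in the forward dispatch:

// Hypothetical: dispatch for a "log_softmax" functional, mirroring the
// relu/max_pool2d/dropout branches above.
else if (XP_STRICMP(szFunctional, "log_softmax") == 0)
{
    int64_t dim = f->Int64Attribute("dim", 1);
    input = F::log_softmax(input, F::LogSoftmaxFuncOptions(dim));
    continue;
}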
With this simple loader in place, it can be tested with the following code:

	BaseNNet bnnt;
	bnnt.LoadNet("VGGD_BatchNorm");

	bnnt.Print();

A VGG-D network with BatchNorm is then loaded, and Print() dumps its structure.


Summary

From now on, a neural network can be defined in the configuration file nnconfig.xml and simply loaded from there, so the remaining effort can go into organizing datasets and implementing train, verify, and test.
