【深度学习】Torch卷积层源码详解

原创 2016年08月31日 19:11:57

本文以前向传播为例,详细分析Torch的nn包中,SpatialConvolution函数的实现方式。
在分析源文件时,同时给出了github上的链接以及安装后的文件位置。

初始化

定义一个卷积层需要如下输入参数

nInputPlane\nOutputPlane    -- 输入\输出通道数,M\N
kW\kH                       -- 核尺寸,K
dW\dH                       -- 步长
padW\padH                   -- 补边

卷积层的核心变量

weight         -- 卷积核权重,N*M*K*K
bias           -- 卷积核偏置,N
gradWeight     -- 权重导数,N*M*K*K
gradBia        -- 偏置导数,N

为效率起见,torch的层采用分层方式实现:

nn(lua)->THNN(C)->THTensor(C)->THBlas(C)->LAPACK(Fortran)

nn(lua)层次

/extra/nn/SpatialConvolution.lua中,定义了卷积层的lua接口。

前向运算的函数是updateOutput(input),其中执行运算的部分如下:

input.THNN.SpatialConvolutionMM_updateOutput(
      input:cdata(),            self.output:cdata(),
      self.weight:cdata(),      THNN.optionalTensor(self.bias),
      self.finput:cdata(),      self.fgradInput:cdata(),
      self.kW, self.kH,         self.dW, self.dH,      self.padW, self.padH
   )

其中input.THNN是输入Tensor的一个C接口,传入的参数也都用:cdata()转化成是C类型。

Torch中另有/extra/nn/SpatialConvolutionMM.lua,未在文档中出现,内容几乎相同,不做分别。

THNN库

THNN是一个C库,包含了nn包中的C实现,可以不依赖Lua运行。

/extra/nn/lib/THNN/generic/THNN.h包含了库中函数的声明。

THNN库中大量采用了宏定义的方式来命名,例如:

TH_API void THNN_(SpatialConvolutionMM_updateOutput)(...)
TH_API void THNN_(SpatialConvolutionMM_updateGradInput)(...)

THNN_开头的函数定义在/extra/nn/lib/THNN/generic/目录下,这两个在SpatialConvolutionMM.c文件中。

其他几个库

顺便辨识一下几个容易混淆的库/包:

  • nn(lua)->THNN(C)
  • cunn(lua)->THCUNN(cuda)

在Torch自己的github下维护;
lua文件在/extra/nn/目录下;
C文件在/extra/nn/lib/THNN/generic/目录下,cuda文件在/extra/nn/lib/THCUNN/目录下;
nn中的数据/层通过:cuda()可以转化为cunn中的数据/层;反之,则使用:float()

  • cudnn(lua)->cuDNN库

在Torch的重要作者soumith的gihub下维护;
lua文件在/extra/cudnn/目录下;
实现部分需要安装cuDNN;
nn中的层可以通过cudnn.convert(net,cudnn)转化为cudnn中的层;反之则使用cudnn.convert(net,nn)

THNN(C)层次

/extra/nn/lib/THNN/generic/SpatialConvolutionMM.c实现了卷积层的核心功能。分三步骤实现。

Step 1

首先,把输入的3D或4D的input展开成2D或3D的finput:

THNN_(unfolded_copy)(finput, input, kW, kH, dW, dH, padW, padH, nInputPlane, inputWidth, inputHeight, outputWidth, outputHeight);

THNN_(unfolded_copy)是Torch中的重要函数,在/extra/nn/lib/THNN/generic/unfold.c中定义。

input尺寸为M*H*W。对于每一通道,根据卷积尺寸,将其进行平移,获得K*K个结果。这些结果摞起来得到(M*K*K)*(H*W)的finput

例:设input尺寸为2*4*4,两通道如下
这里写图片描述

使用3*3卷积核时,每一通道共有3*3=9个平移结果。卷积模板9个像素位置对应的平移为:
这里写图片描述
平移1=右移1+下移1:
这里写图片描述

平移4 = 右移1
这里写图片描述

相应地,把卷积权重weight也整理成2D矩阵N*(M*K*K)

Step 2

接下来,创建N*(H*W)的输出矩阵output:

  output2d = THTensor_(newWithStorage2d)(output->storage, output->storageOffset,
                                nOutputPlane, -1, outputHeight*outputWidth, -1);

THTensor_开头的函数在/pkg/torch/lib/TH/generic/THTensor.h中声明。

然后把卷积层的bias逐通道地复制到输出output中。

for(i = 0; i < nOutputPlane; i++)
        THVector_(fill)(output->storage->data+output->storageOffset+output->stride[0]*i, THTensor_(get1d)(bias, i), outputHeight*outputWidth);

相似地,THVector_开头的函数直接在/pkg/torch/lib/TH/generic/THVector.c中声明和定义。

Step 3

平移展开后的输入finput,通过与weight的矩阵乘法,得到N*M*H*W的卷积结果output
这里写图片描述

这一步是卷积的核心,通过/pkg/torch/lib/TH/generic/THTensorMath.c中的函数实现:

THTensor_(addmm)(output2d, 1, output2d, 1, weight, finput);

THTensor层次

许多以THTensor_开头的函数都定义在/pkg/torch/lib/TH/generic/目录下,包括THTensor.c,THTensorConv.c,THTensorRandom.c等。前述矩阵乘法定义在THTensorMath.c中。

经过一系列合法性检查,执行乘法的是一个THBlas_函数:

  THBlas_(gemm)(transpose_m1,
                transpose_m2,
                r__->size[(transpose_r == 'n' ? 0 : 1)],
                r__->size[(transpose_r == 'n' ? 1 : 0)],
                m1_->size[(transpose_r == 'n' ? 1 : 0)],
                alpha,
                THTensor_(data)(m1_),
                (transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]),
                THTensor_(data)(m2_),
                (transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]),
                beta,
                THTensor_(data)(r__),
                r__->stride[(transpose_r == 'n' ? 1 : 0)]);

其中m1是卷积权重,m2是展开的输入。

THBlas(C)层次

/pkg/torch/lib/TH/generic/THBlas.c包含THBlas_(gemm)的实现。

根据数据的类型(double/float),调用LAPACKdgemm_sgemm_函数:

#if defined(TH_REAL_IS_DOUBLE)
    dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
#else
    sgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
#endif
版权声明:本文为博主原创文章,未经博主允许不得转载。

日常笔记:Lua & Torch

Code Zoo,这个名字取自于深度学习框架 caffe 中著名的模型集合:Caffe Model Zoo。就是想把我平时用 Torch 时做深度学习时,写的 Lua 代码,如脚本工具、Lua 和 T...
  • u010167269
  • u010167269
  • 2016年07月14日 23:22
  • 6639

lua,torch,nn模块入门笔记

最近看到好多论文的神经网络都是用lua基于torch实现的,于是迫不得已学学lua和torch,才能看懂人家的代码。教程首先看教程: Learn Lua in 15minites! Torch 7,H...
  • hejunqing14
  • hejunqing14
  • 2016年08月09日 16:52
  • 11731

深度卷积网络图像风格转移(三)代码分析

理解 Deep Photo Style Transfer源代码
  • cicibabe
  • cicibabe
  • 2017年05月17日 11:37
  • 759

Triplet Loss、Coupled Cluster Loss 探究

因为要区分相似图像,所以研究了一下 Triplet Loss,还有今年 CVPR 的一篇文章:《Deep Relative Distance Learning: Tell the Difference...
  • u010167269
  • u010167269
  • 2016年07月25日 20:49
  • 7208

Torch7学习(七)——Neural-Style代码解析

neural style代码解析Torch的框架
  • Hungryof
  • Hungryof
  • 2016年07月26日 16:52
  • 7331

DL学习笔记【17】nn包中的各位Convolutional layers

本来想写在笔记本上的,然而。。到了图书馆发现没有带笔也是囧。。 用一句喜欢的话开始这篇博文:if you can't explain it simply, you don't understand i...
  • Sun7_She
  • Sun7_She
  • 2017年03月18日 21:20
  • 1959

torch matric operation

linspace(a,b,N)得到一个a到b之间的等差数组,a为起点,b为终点。数组的间隔为(a-b)/(N-1) 另外一个数组的shape为(5, )时,数组的长度为5arr.transpose(...
  • yiqingyang2012
  • yiqingyang2012
  • 2017年01月20日 00:35
  • 372

Triplet Loss、Coupled Cluster Loss 探究

因为要区分相似图像,所以研究了一下 Triplet Loss,还有今年 CVPR 的一篇文章:《Deep Relative Distance Learning: Tell the Difference...
  • u010167269
  • u010167269
  • 2016年07月25日 20:49
  • 7208

深度学习笔记(四)用Torch实现MNIST手写数字识别

本节代码地址: https://github.com/vic-w/torch-practice/tree/master/mnist MNIST是手写数字识别的数据库。在深度学习流行的今天,MNI...
  • revolver
  • revolver
  • 2015年11月06日 13:54
  • 11892

Torch7学习(五)——学习神经网路包的用法(3)

总说这篇博客主要是讲对于只有简单层的神经网络,进行手动挡的训练方法。 以及卷积层模块以及criterion模块。例子Simple Layer的Add层。module = nn.Add(inputDi...
  • Hungryof
  • Hungryof
  • 2016年07月25日 22:12
  • 3511
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:【深度学习】Torch卷积层源码详解
举报原因:
原因补充:

(最多只允许输入30个字)