本文首先介绍了计算图的自动求导方法,然后对卷积运算中Kernel和Input的梯度进行了推导,之后基于Pytorch实现了卷积算子并做了正确性检验。
本文主要有两个目的:
- 推导卷积运算各个变量的梯度公式;
- 学习如何扩展Pytorch算子,自己实现了一个能够forward和backward的卷积算子;
首先介绍了计算图的自动求导方法,然后对卷积运算中Kernel和Input的梯度进行了推导,之后基于Pytorch实现了卷积算子并做了正确性检验。
本文的代码在这个GitHub仓库:https://github.com/dragonylee/myDL/blob/master/%E6%89%A9%E5%B1%95%E6%B5%8B%E8%AF%95.ipynb。
计算图
计算图(Computational Graphs)是torch.autograd
自动求导的理论基础,描述为一个有向无环图(DAG),箭头的方向是前向传播(forward)的方向,而逆向的反向传播(backward)的过程可以很方便地对任意变量求偏导。为了方便说明,这里举一个简单的例子:
在Pytorch(Python)里定义上述三个函数:
然后用torchviz
可视化其复合函数的计算图:
得到如下结果:
忽略“Accumulate”这个操作,在该计算图上的反向求导过程表示如下:
这很清晰地展示了计算图的功能,它记录了每一个变量(包括输出、中间变量)的计算函数(可以称之为一个算子,就是图中的方框,入边是输入,出边是输出),从而可以数值计算出相应的导数。实际上,任何变量qqq对ppp求导都可以对两者之间的反向链路进行累乘得到。
输出结果和上图的计算结果一致。注意在backward过程中非叶子节点可以调用.retain_grad()
来记录grad。
以前我一直以为自动求导是一个很复杂的操作,没想到一个计算图就非常简洁地实现了,才发现“我以为”的复杂操作其实是形式化的求导……
卷积运算与梯度推导
本文所涉及的卷积运算是最平凡的卷积运算,不包含stride, padding, dilation, bias等。定义卷积运算
如何实现卷积?
在这里就可以直接用Einstein求和标记将卷积运算写出来了:
代码为
如何计算梯度?
这部分求导的推导是我自己在草稿纸上完成的,后面经过一些验证应该或许可以保证是正确的。
为了方便推导,先不考虑batch和channel,也就是Input, Kernel, Output都是二维的。
Kernel的梯度
根据链式求导法则我们可以将此导数(偏导)写作
那么问题就是求Output对Kernel的偏导,我们用一个简单的例子来推导:
Input的梯度
同样,Input的梯度可以写作
自定义卷积算子
本文的一个很大目的,就是让我自己学会怎么扩展Pytorch的算子,从官方文档了解到,需要实现一个继承torch.autograd.Function
的函数,并且实现forward
和backward
静态函数,才能适应Pytorch的自动求导框架,有一些需要注意的细节:
forward
和backward
函数的第一个参数都是ctx
,就是context的意思,与self
类似,一般如果在backward过程中要用到forward的参数,在forward时就要调用ctx.save_for_backward()
保存起来;forward
有多少个输入,backward
就要有多少个输出,这个看计算图就能明白了,如果不需要求梯度的入边,可以返回None
;
梯度求解
前面在定义卷积运算时,都是考虑了Batch和Channel的,而在推导对Input和Kernel的梯度时,却为了方便没有考虑这两个参数。实际上在实现时,要特别注意每个数据的view
的每个维度之间的关系。
例如我这里定义的:
求Input的梯度也是类似。
代码
正确性验证
torch.autograd.gradcheck
提供了检验梯度运算正确性的工具,它的原理是,给定输入,用你写的算子的backward计算一个output和input的雅各比矩阵,然后再用有限差分的方法计算一个数值解,然后对比这两个结果是否一致。
验证上面的MyConv2dFunc
算子的正确性:
输出为True
。
自定义卷积层模型
需要继承nn.Module
,并且用nn.Parameter
保存权重,也就是卷积核。还要实现forward
方法。
基于MNIST的测试
使用的卷积神经网络模型为LeNet:
任务是对MNIST手写体数字进行分类。
首先用Pytorch自带的Conv、Linear这些网络层搭建然后训练,然后把网络中的Conv2d
替换为我写的MyConv2d
做同样的训练,得到的结果如下(5个epoch, CUDA):
Accuracy | time cost(s) | |
nn.Conv2d | 99.2% | 33.72 |
MyConv2d | 99.1% | 76.49 |