Title | Content |
---|---|
原文 | Binarized Neural Networks: Training Neural Networks with Weights and Activations Constrained to +1 or −1 |
作者 | Matthieu Courbariaux |
说明 | 只包含训练过程的内容学习 |
摘要
- 本文介绍一种训练 BNNs 的方法,也就是运行时训练二值weights和activations的神经网络
- 在forward pass期间,BNNs能够急剧的降低内存和访问,且用按位操作代替大部分的算数运算
- 训练和运行本文的BNN可以实现on-line
简介
DNNs很强大,但是在目标低耗设备上运行很有挑战
文章贡献:
(1)提出训练BNNs的方法:
NN在运行时、训练时计算parameters gradients时,是binary的weights和activations
(2)2个sets的实验,在不同的framework上实现
(3)发现在forward pass阶段(也就是运行和训练阶段),BNNs能够急剧的降低内存和访问,且用按位操作代替大部分的算数运算。
1. BNNs
1.1. 确定性 vs 随机性二值化
训练BNN时,限制weights和activations为+1或-1
本文用两个不同的二值化函数
(1)确定性函数:
xb=Sign(x)={+1,−1,if x≥0otherwise
xb
是二值变量(weight or activation)
x
是实值变量
(2)随机函数:
δ 是 hard sigmoid 函数
随机函数优于确定性函数
难实现,因为要求硬件在quantizing时产生随机位
所以实际常用确定性函数
1.2. 梯度计算与累积
虽然BNN训练用二进制weights和activation来计算parameters gradients,weights的实值梯度会在实值变量当中累积
随机梯度下降(SGD)要求实值weights
accumulators的分辨率很重要
在weights和activations计算parameters gradients时,加入噪声可以更好的帮助泛化
算法1:训练BNN
(1)说明
C
:mini-batch的成本函数
L
:层数
Binarize()函数:表示如何二值化weights or activations
Clip()函数:修改weight,因为实值weight可能会变得很大,但是却对bianry weight没影响
BatchNorm()函数:表示如何batch normalize activations,用batch normalization or 算法3的基于shift的方法
BackBatchNorm()函数:表示如何通过正则化反向传播
Updata()函数:表示当梯度已知时如何更新parameters,用ADAM或算法4的基于shift的方法
(2)算法内容
Require:
minibatch的输入与目标:
(a0,a∗)
初始化权重系数:
γ
前一个weight:
W
前一个BatchNorm参数:
前一个学习率:
η
Ensure:
更新权重:
Wt+1
更新BatchNorm参数:
θt+1
更新学习率:
γt+1
计算parameters gradients:
前向传播:
for k=1 to L do
Wbk←Binarize(Wk)
\\每层权重二值化
sk←abk−1Wbk
\\每层输入乘以权重,得到一个输出
ak←BatchNorm(sk,θk)
\\上一步的输出,激活后变成输入
if k<L then
abk←Binarize(ak)
end if
end for
后向传播:
\\梯度不是二进制的
计算 straight-through estimator
……
(公式好复杂呀~懒得敲了,还是看paper吧~)
1.3. 通过离散传递梯度
sign函数的导数为0,明显不适合返现传播,因为在离散前,关于quantities的cost梯度为0 ???
前人研究在随机离散神经元间estimating or propagating梯度时,发现“straight-through estimator” 是最最快的训练方式
本文采用“straight-through estimator”变体
(1)考虑饱和效应
(2)用确定性的sampling of the bit
对于隐藏层单元
(1)用sign函数的非线性获得二值activations
(2)weights需要结合两部分:
1)限制实值weight在-1和1间,当weight的更新使得
wr
超出
[−1,1]
则直接映射成-1 or 1
2)当使用
wr
时,用
wb=Sign(wr)
做quantize
算法2
1.4. Shift based Batch Normalization
Batch Normalization(BN)可以:
(1)加速训练
(2)减少weights scale的全局影响
正则化噪声可以帮助调整模型
但是在训练时,BN需要很多乘法
所以采用shift based BN(SBN)
SBN近似BN,却几乎没有乘法,且实验没有丢失精度
算法3
1.5. Shift based AdaMax
ADAM learning rule 也可以减少weight scale的影响
但是在训练时,ADAM需要很多乘法
所以采用Shift based AdaMax
实验没有丢失精度
算法4
1.6. 第一层
一个BNN中,只有weights和activations的二值化值被用到所有的计算中
由于每一层的输出是下一层的输入,所以除了第一层,所有层的输入是二进制的
算法5