论文来源:Dynamic Network Surgery for Efficient DNNs
DNS-修剪示意图
具体算法流程:
算法实现分两步:第一不剪掉不重要的连接,第二对可能重要性恢复的连接进行拼接恢复。
这里有一点疑惑:怎样判断被误检而剪掉的连接?(实际上代码中对前一次迭代中已剪枝的掩码有恢复的操作)
具体的上两步实现通过增加掩码矩阵T_k进行,掩码的更新:
// 1、Calculate the mean and standard deviation of learnable parameters
if (this->std==0 && this->iter_==0)
{
unsigned int ncount = 0;
for (unsigned int k = 0;k < this->blobs_[0]->count(); ++k) {
this->mu += fabs(weightMask[k]*weight[k]);
this->std += weightMask[k]*weight[k]*weight[k];
if (weightMask[k]*weight[k]!=0) ncount++;
}
if (this->bias_term_) {
for (unsigned int k = 0;k < this->blobs_[1]->count(); ++k) {
this->mu += fabs(biasMask[k]*bias[k]);
this->std += biasMask[k]*bias[k]*bias[k];
if (biasMask[k]*bias[k]!=0) ncount++;
}
}
this->mu /= ncount; this->std -= ncount*mu*mu;
this->std /= ncount; this->std = sqrt(std);
}
当前层第一次进入训练时,计算该层的均值和均方值:mu与std
// 2、Calculate the weight mask and bias mask with probability
#gamma和power是用来控制更新掩码-T的频率的
#Probability = (1+gamma*iter)^-power
#即原文中所说的一个减函数(控制剪枝后网络收敛)
Dtype r = static_cast<Dtype>(rand())/static_cast<Dtype>(RAND_MAX);
if (pow(1+(this->gamma)*(this->iter_),-(this->power))>r && (this->iter_)<(this->iter_stop_)) {
for (unsigned int k = 0;k < this->blobs_[0]->count(); ++k) {
if (weightMask[k]==1 && fabs(weight[k])<=0.9*std::max(mu+crate*std,Dtype(0)))
weightMask[k] = 0;
else if (weightMask[k]==0 && fabs(weight[k])>1.1*std::max(mu+crate*std,Dtype(0)))
weightMask[k] = 1;
}
条件判断语句if下的关系操作对应原论文中:
通过计算的均值和均方值进行掩码操作(概率),参见源代码中给出的crate=4,掩码矩阵设定了两个阈值(a_k和b_k,其中a_k=0.9*Max(mu+crate*std , 0,b_k=1.1*Max(mu+crate*std , 0))进行更新。
这里解释下两侧不等式联系到标准差和均值的具体含义:
如上图,我们目的在将连接的绝对值在小于一定阈值的情况下进行删减,图中深蓝色区域距离均值处最近,即靠近均值距离的计算=》distance-mu = crate * std,这里crate为我们试图剪枝压缩比率,人为设定的超参数。
代码中(关于1和0的判断)实现了关于拼接-splice的概念的实现。
if (this->bias_term_) {
for (unsigned int k = 0;k < this->blobs_[1]->count(); ++k) {
if (biasMask[k]==1 && fabs(bias[k])<=0.9*std::max(mu+crate*std,Dtype(0)))
biasMask[k] = 0;
else if (biasMask[k]==0 && fabs(bias[k])>1.1*std::max(mu+crate*std,Dtype(0)))
biasMask[k] = 1;
}
}
}
根据上步中的掩码更新规则更新本层的参数后,接着对权值和偏置进行更新。
// 3、Calculate the current (masked) weight and bias
for (unsigned int k = 0;k < this->blobs_[0]->count(); ++k) {
weightTmp[k] = weight[k]*weightMask[k];
}
if (this->bias_term_){
for (unsigned int k = 0;k < this->blobs_[1]->count(); ++k) {
biasTmp[k] = bias[k]*biasMask[k];
}
}
将更新好的权值和偏差带入前向计算中
// 4、Forward calculation with (masked) weight and bias
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->cpu_data();
Dtype* top_data = top[i]->mutable_cpu_data();
for (int n = 0; n < this->num_; ++n) {
this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weightTmp,
top_data + top[i]->offset(n));
if (this->bias_term_) {
this->forward_cpu_bias(top_data + top[i]->offset(n), biasTmp);
}
}
}