DNS-动态外科手术关键细节理解(结合caffe源码)

论文来源:Dynamic Network Surgery for Efficient DNNs

                                                       DNS-修剪示意图

具体算法流程:

算法实现分两步:第一不剪掉不重要的连接,第二对可能重要性恢复的连接进行拼接恢复。

这里有一点疑惑:怎样判断被误检而剪掉的连接?(实际上代码中对前一次迭代中已剪枝的掩码有恢复的操作

具体的上两步实现通过增加掩码矩阵T_k进行,掩码的更新:

// 1、Calculate the mean and standard deviation of learnable parameters

if (this->std==0 && this->iter_==0)
{      
	unsigned int ncount = 0;
	for (unsigned int k = 0;k < this->blobs_[0]->count(); ++k) {
		this->mu  += fabs(weightMask[k]*weight[k]);       
		this->std += weightMask[k]*weight[k]*weight[k];
		if (weightMask[k]*weight[k]!=0) ncount++;
	}
	if (this->bias_term_) {
		for (unsigned int k = 0;k < this->blobs_[1]->count(); ++k) {
			this->mu  += fabs(biasMask[k]*bias[k]);
			this->std += biasMask[k]*bias[k]*bias[k];
			if (biasMask[k]*bias[k]!=0) ncount++;
		}       
	}
	this->mu /= ncount; this->std -= ncount*mu*mu; 
	this->std /= ncount; this->std = sqrt(std);
} 

当前层第一次进入训练时,计算该层的均值和均方值:mu与std

// 2、Calculate the weight mask and bias mask with probability

#gamma和power是用来控制更新掩码-T的频率的

#Probability = (1+gamma*iter)^-power

#即原文中所说的一个减函数(控制剪枝后网络收敛)

Dtype r = static_cast<Dtype>(rand())/static_cast<Dtype>(RAND_MAX);
if (pow(1+(this->gamma)*(this->iter_),-(this->power))>r && (this->iter_)<(this->iter_stop_)) { 	
	for (unsigned int k = 0;k < this->blobs_[0]->count(); ++k) {
		if (weightMask[k]==1 && fabs(weight[k])<=0.9*std::max(mu+crate*std,Dtype(0))) 
			weightMask[k] = 0;
		else if (weightMask[k]==0 && fabs(weight[k])>1.1*std::max(mu+crate*std,Dtype(0)))
			weightMask[k] = 1;
	}	

条件判断语句if下的关系操作对应原论文中:

通过计算的均值和均方值进行掩码操作(概率),参见源代码中给出的crate=4,掩码矩阵设定了两个阈值(a_k和b_k,其中a_k=0.9*Max(mu+crate*std , 0,b_k=1.1*Max(mu+crate*std , 0))进行更新。

这里解释下两侧不等式联系到标准差和均值的具体含义:

如上图,我们目的在将连接的绝对值在小于一定阈值的情况下进行删减,图中深蓝色区域距离均值处最近,即靠近均值距离的计算=》distance-mu = crate * std,这里crate为我们试图剪枝压缩比率,人为设定的超参数。

代码中(关于1和0的判断)实现了关于拼接-splice的概念的实现。

	if (this->bias_term_) {       
		for (unsigned int k = 0;k < this->blobs_[1]->count(); ++k) {
			if (biasMask[k]==1 && fabs(bias[k])<=0.9*std::max(mu+crate*std,Dtype(0))) 
				biasMask[k] = 0;
			else if (biasMask[k]==0 && fabs(bias[k])>1.1*std::max(mu+crate*std,Dtype(0)))
				biasMask[k] = 1;
		}    
	} 
}

根据上步中的掩码更新规则更新本层的参数后,接着对权值和偏置进行更新。

// 3、Calculate the current (masked) weight and bias

for (unsigned int k = 0;k < this->blobs_[0]->count(); ++k) {
	weightTmp[k] = weight[k]*weightMask[k];
}
if (this->bias_term_){
	for (unsigned int k = 0;k < this->blobs_[1]->count(); ++k) {
		biasTmp[k] = bias[k]*biasMask[k];
	}
}

将更新好的权值和偏差带入前向计算中

// 4、Forward calculation with (masked) weight and bias

for (int i = 0; i < bottom.size(); ++i) {
	const Dtype* bottom_data = bottom[i]->cpu_data();
	Dtype* top_data = top[i]->mutable_cpu_data();
	for (int n = 0; n < this->num_; ++n) {
	  this->forward_cpu_gemm(bottom_data + bottom[i]->offset(n), weightTmp,
		  top_data + top[i]->offset(n));
	  if (this->bias_term_) {
		this->forward_cpu_bias(top_data + top[i]->offset(n), biasTmp);
	  }
	}
}

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值