根据BN裁剪模型channel

import sys                                                                                                    
import caffe                                                                                                                                     
import numpy as np                                                       
                                                                         
percent = 0.5                                                            
np.set_printoptions(threshold=sys.maxsize)                               
                                                                         
MODEL_FILE = '********.prototxt' #模型文件                           
PRETRAIN_FILE = '*********.caffemodel'   #参数文件                 
                                                                         
params_txt = 'params.txt'   #保存参数文件                                             
pf = open(params_txt, 'w')                                               
                                                                         
total = 0  #统计BN参数的个数                                                           
                                                                         
net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)   #加载模型和参数                
print("load model successful")                                           
for param_name in net.params.keys():                                     
    weight = net.params[param_name][0].data                              
    bias = net.params[param_name][1].data                                
    if 'bn' in param_name:                                               
        total += weight.shape[0]                                         
print("the number of bn is :", total)                                    
bn = []                                                                  
index = 0                                                                
small_num = 0                                                            
for param_name in net.params.keys():                                     
    weight = net.params[param_name][0].data  #gama                       
    bias = net.params[param_name][1].data  #beta                            
    if 'bn' in param_name:                                               
        size = weight.shape[0]                                           
        weight.shape = (-1, 1)                                           
        for w in weight:                                                 
            bn.append(abs(w))  
                                          
y = sorted(bn)    #根据BN 的gama值进行排序                                                       
thre_index = int(total * percent)    #根据设置的裁剪比例来确定gama阈值

print("thread is :", thre)                                               │
# 这里并没有裁剪,只是对要裁剪的gama置0,测试了一下结果 
# 可能会出现某些layer的channel被全部裁剪的情况,对于这种情况,保留其gama绝对值最大的那个channel                                                                       
pruned = 0                                                               │
cfg = []                                                                 │
cfg_mask = []                                                            │
for param_name in net.params.keys():                                     │
    weight = net.params[param_name][0].data                              │
    bias = net.params[param_name][1].data                                │
    if 'bn' in param_name:                                               │
        size = weight.shape[0]                                           │
        print(param_name)                                                │
        num = 0                                                          │
        index = 0                                                        │
        max_id = 0                                                       │
        max_gama = 0                                                     │
        max_b = 0                                                        │
                                                                         │
        for w in weight:                                                 │
            if abs(w) > max_gama:                                        │
                max_id = index                                           │
                max_gama = w                                             │
                max_b = net.params[param_name][1].data[index]            │
                                                                         │
            if abs(w) < thre:                                            │
                num +=1                                                  │
                small_num +=1                                            │
                net.params[param_name][0].data[index] = 0                │
                net.params[param_name][1].data[index] = 0                │
                                                                         │
            index +=1                                                    │
        if num == size:                                                  │
            net.params[param_name][0].data[max_id] = max_gama            │
            net.params[param_name][1].data[max_id] = max_b               │
                                                                         │
        print(size,num)                                                  │
print(small_num)#    

net.save('at_least_one_new.caffemodel')       

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值