论文复现:Learning Efficient Convolutional Networks through Network Slimming

论文核心

论文提出了一种结构化剪枝策略,剪枝对象为 channel ,对 channel 重要性的评价标准使用的是 Batch Normalization 层中的缩放因子,这不会给网络带来额外的开销。

在这里插入图片描述


论文细节品读

L 1 L1 L1正则的损失函数:
首先得了解 L 1 L1 L1正则为何能带来稀疏性,相关解释链接
于是论文作者为了诱导 B N BN BN层缩放因子 γ \gamma γ产生稀疏性,对 B N BN BN层的 γ \gamma γ使用 L 1 L1 L1正则,于是更新后的损失函数如下:
L = ∑ ( x , y ) l ( f ( x , W ) , y ) + λ ∑ γ ∈ Γ g ( γ ) { L=\sum\limits_{(x,y)}l(f(x,W),y)+\lambda\sum\limits_{\gamma\in\Gamma}g(\gamma) } L=(x,y)l(f(x,W),y)+λγΓg(γ)
而这多出的 L 1 L1 L1正则化项不是处处可导的,反向传播时需要把该部分单独处理。这在论文复现部分讨论。

经典三步走:
同样采用了这里的三步走方式以获取最大剪枝率和精度,这里特点是在训练反向传播过程中加入了对 γ \gamma γ的稀疏诱导。
在这里插入图片描述


论文复现

准备:
模型选择resnet18,优化器选择 SGD,等等。保证和上个论文复现实验基本条件一致。上篇论文复现
γ \gamma γ处理方式:
首先对上面内容填坑,给出论文作者是如何处理 L 1 L1 L1正则化下项无法求导(严格的说是不能处处求导,在 x = 0 x=0 x=0处无法求导)从而无法使用传统的梯度下降法的。下面是源码部分:

def updateBN():
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s*torch.sign(m.weight.data))  # L1

BN 层中先对 γ \gamma γ求导,也就是 torch.sign(m.weight.data),其实求导的值只有0,1,-1三个。然后乘以一个很小的系数,一般选择0.0001,最后再将该部分的值加入到上一次的 γ \gamma γ导数值之中。这个过程在反向传播。

data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        # 反向传播时更新γ的梯度值
        if args.sr:
            updateBN()
        optimizer.step()

Channel 剪枝:
源码是先统计所有 feature map 的总 channel 数,也就是 γ \gamma γ总个数。
由于源码给出的是 VGG 网络的剪枝,而我实验的网络为 resnet18,其中存在 Shoutcut 结构,因此不能像 VGG 一样无脑的统计所有 channel 数, 需要特殊的处理方式。因为论文中也没有提到对 Shoutcut 的特殊处理方式,所以这里就自由发挥了。
为了简化实验,我选择将 Shoutcut 连接的 feature map 不做剪枝处理,这实际是只对8个 feature map 剪枝。下图中被红色框框选的 block 是我要剪枝的目标。
在这里插入图片描述

下面是我关键思路的代码,这部分代码参杂较多个人修改的东西,如有不恰当的地方,请指正

# channel 剪枝 --- Learning Efficient Convolutional Networks through Network Slimming
def prune_channel(model, prune_rates):
    total = 0
    count = 0
    # 和shortcut不相关的block,会被裁剪
    prune_block = [1, 3, 5, 8, 10, 13, 15, 18]
    # basicblock 中和 shoutcut关联的block
    block_basic_sc_connect = [2, 4, 6, 9, 11, 14, 16, 19]
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            if prune_block.count(count) == 1:
                total += m.weight.data.shape[0]
            count += 1
    bn = torch.zeros(total)

    index = 0
    count = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            if prune_block.count(count) == 1:
                size = m.weight.data.shape[0]
      
  • 8
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值