剪枝论文五(Connection pruning)

本文介绍一种经典的剪枝方法(Connection pruning)

1. 核心思想

在这里插入图片描述
步骤:

  1. 通过正常的网络训练学习连通性。然而,与传统训练不同的是,我们不是在学习权重的最终值,而是在学习哪些连接是重要的。
  2. 修剪低重量连接。所有权值低于阈值的连接都被从网络中删除,将一个密集网络转换为一个稀疏网络。
  3. 训练网络学习剩余稀疏连接的最终权值。

2. 算法

  1. 记录卷积层的所有权重
    # 统计卷积层的通道数
    total = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            total += m.weight.data.numel()

    # conv_weights为长度为total的一维向量,记录所有卷积层的权重
    conv_weights = torch.zeros(total).cuda()
    index = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            size = m.weight.data.numel()
            conv_weights[index:(index+size)] = m.weight.data.view(-1).abs().clone()
            index += size
  1. 标记权值小于阈值的部分权重
    thre_index = int(total * args.percent)
    thre = y[thre_index]
    pruned = 0
    print('Pruning threshold: {}'.format(thre))
    zero_flag = False
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.Conv2d):
            weight_copy = m.weight.data.abs().clone()

            # 权值小于thre的mask=0(裁剪), 其余为1(保留)
            mask = weight_copy.gt(thre).float().cuda()

            # 被裁剪的神经元数量 = 之前层的数量 + 这一层的总数 - 这一层的保留数量
            pruned = pruned + mask.numel() - torch.sum(mask)

            # 权重相乘
            m.weight.data.mul_(mask)

            # zero_flag记录这一层是否被全部裁剪
            if int(torch.sum(mask)) == 0:
                zero_flag = True

这里只是理论上mask掉了,实际运算量没有减少,但是源代码只给到了这里。

mask之后的步骤和我写的前几篇差不多,可以参考剪枝论文二(Filters Pruning)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值