本文介绍一种经典的剪枝方法(Connection pruning)
1. 核心思想
步骤:
- 通过正常的网络训练学习连通性。然而,与传统训练不同的是,我们不是在学习权重的最终值,而是在学习哪些连接是重要的。
- 修剪低重量连接。所有权值低于阈值的连接都被从网络中删除,将一个密集网络转换为一个稀疏网络。
- 训练网络学习剩余稀疏连接的最终权值。
2. 算法
- 记录卷积层的所有权重
# 统计卷积层的通道数
total = 0
for m in model.modules():
if isinstance(m, nn.Conv2d):
total += m.weight.data.numel()
# conv_weights为长度为total的一维向量,记录所有卷积层的权重
conv_weights = torch.zeros(total).cuda()
index = 0
for m in model.modules():
if isinstance(m, nn.Conv2d):
size = m.weight.data.numel()
conv_weights[index:(index+size)] = m.weight.data.view(-1).abs().clone()
index += size
- 标记权值小于阈值的部分权重
thre_index = int(total * args.percent)
thre = y[thre_index]
pruned = 0
print('Pruning threshold: {}'.format(thre))
zero_flag = False
for k, m in enumerate(model.modules()):
if isinstance(m, nn.Conv2d):
weight_copy = m.weight.data.abs().clone()
# 权值小于thre的mask=0(裁剪), 其余为1(保留)
mask = weight_copy.gt(thre).float().cuda()
# 被裁剪的神经元数量 = 之前层的数量 + 这一层的总数 - 这一层的保留数量
pruned = pruned + mask.numel() - torch.sum(mask)
# 权重相乘
m.weight.data.mul_(mask)
# zero_flag记录这一层是否被全部裁剪
if int(torch.sum(mask)) == 0:
zero_flag = True
这里只是理论上mask掉了,实际运算量没有减少,但是源代码只给到了这里。
mask之后的步骤和我写的前几篇差不多,可以参考剪枝论文二(Filters Pruning)。