毕业论文需要,所以近期在学习剪枝的东西,简单学了一下基础的东西,担心以后忘记,便进行简单的记录,以下是剪枝的主要流程:
(1)模型的稀疏训练:这是剪枝之前非常重要的一步,在BN层中添加L1正则项可以使模型在训练过程中带来稀疏性。
def updateBN():
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
# sign是符号函数
m.weight.grad.data.add_(args.s*torch.sign(m.weight.data)) # L1
(2)模型的剪枝:
简单的用四步遍历来进行流程上的讲解吧:
①第一步遍历,记录原始模型BN层的通道总数;(便于后面计算模型要裁剪的总数);
②第二步遍历,先用上一步的通道数来构建一个空列表bn,用来存放所有的BN层的权重参数,之后对这些权重参数进行排序,然后根据剪枝率计算要裁剪的阈值;
③第三次遍历,利用上一次遍历计算出来的阈值来构建一个BN层的mask(大于阈值为1,小于阈值为0),并用此来获取剪枝后模型的配置(cfg)。此时只要将原始权重 * mask,这样模型就是增加了mask的模型了(需要注意的是,效果上和剪枝之后是一样的,但是参数量并没有变少,只是设置为0了);