pytorch自带的模型剪枝工具prune的使用

torch.nn.utils.prune可以对模型进行剪枝,官方指导如下:

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

直接上代码

首先建立模型网络:

import torch
import torch.nn as nn
from torchsummary import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class SimpleNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(in_features=16 * 16 * 24, out_features=num_classes)
    def forward(self, input):
        output = self.conv1(input)
        output = nn.ReLU()(output)
        output = self.conv2(output)
        output = nn.ReLU()(output)
        output = self.pool(output)
        output = self.conv3(output)
        output = nn.ReLU()(output)
        output = self.conv4(output)
        output = nn.ReLU()(output)
        output = output.view(-1, 16 * 16 * 24)
        output = self.fc(output)
        return output
model = SimpleNet().to(device=device)

看一下模型的 summary

summary(model, input_size=(3, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

打印一下模型结构各层的名称:

print(model.state_dict().keys())

结果:

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'conv4.weight', 'conv4.bias', 'fc.weight', 'fc.bias'])

接下来 对其进行剪枝操作:

import torch.nn.utils.prune as prune
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.conv4, 'weight'),
    (model.fc, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

执行结束后,再打印一下:

print(model.state_dict().keys())

结果:

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'conv3.weight', 'conv3.bias', 'conv4.bias', 'conv4.weight_orig', 'conv4.weight_mask', 'fc.bias', 'fc.weight_orig', 'fc.weight_mask'])

我们发现剪枝结束后conv*.weight已经 消失了,出现了两个weight:weight_orig和weight_mask

其实weight_orig就是剪枝以前的weight,而weight_mask里面 只是0和1,0代表的是被剪枝的

打印一下:

print(model.state_dict()['conv1.weight_orig'])

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 0.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]], device='cuda:0')
prune.remove(module, 

剪枝后,其实还是比较鸡肋的,因为只是剪之后的神经元相当于置零了,模型大小不会变,下面打印一下,有点dropout的意思了

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 12, 512, 512]             336
            Conv2d-2         [-1, 12, 512, 512]           1,308
         MaxPool2d-3         [-1, 12, 256, 256]               0
            Conv2d-4         [-1, 24, 256, 256]           2,616
            Conv2d-5         [-1, 24, 256, 256]           5,208
            Linear-6                   [-1, 10]          61,450
================================================================
Total params: 70,918
Trainable params: 70,918
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 3.00
Forward/backward pass size (MB): 78.00
Params size (MB): 0.27
Estimated Total Size (MB): 81.27
----------------------------------------------------------------

是不是和剪枝之前实际上是一样的,可能会减少运算,但是似乎好像知乎大神提到的被证明运算也没啥提升

  • 7
    点赞
  • 60
    收藏
    觉得还不错? 一键收藏
  • 19
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值