pytorch深度学习入门(13)之-模型剪枝

概述

模型剪枝是一种用于神经网络压缩的技术,其主要目的是减少模型的计算复杂性和存储需求,同时尽量保持模型的预测能力。这通常通过删除模型中的冗余信息或减少模型的大小来实现。

剪枝技术主要有以下几种:

重要性剪枝:这种方法首先确定模型中每个权重的重要性,例如可以使用梯度或激活值来判断。然后,删除重要性低的权重,并重新训练模型以调整剩余的权重。
全局剪枝:全局剪枝方法通过对整个网络应用某种全局标准(例如,阈值)来删除权重。这种方法通常在预训练的网络上应用,以减少其大小。
结构化剪枝:这种方法涉及删除网络中的特定层或连接。结构化剪枝通常在预训练的网络上应用,并且可以通过迭代地应用不同的剪枝策略来逐步减小网络的大小。
迭代剪枝:这种方法涉及在每次迭代中应用不同的剪枝策略。例如,可以使用重要性剪枝来删除一些权重,然后使用全局剪枝来进一步减小网络的大小。
混合剪枝:这种方法结合了多种剪枝策略,以实现最佳的压缩效果。例如,可以先使用重要性剪枝来删除一些权重,然后使用全局剪枝来进一步减小网络的大小。
需要注意的是,剪枝技术可能会对模型的性能产生影响,因此需要在压缩模型和保持模型性能之间找到一个平衡点。此外,剪枝后的模型可能需要重新训练以调整剩余的权重。
最先进的深度学习

技术依赖于难以部署的过度参数化模型。相反,生物神经网络已知使用有效的稀疏连接。通过减少模型中的参数数量来确定压缩模型的最佳技术非常重要,这样可以在不牺牲准确性的情况下减少内存、电池和硬件消耗。这反过来又允许您在设备上部署轻量级模型,并通过私有设备上计算来保证隐私。在研究前沿,剪枝用于研究过度参数化和欠参数化网络之间学习动态的差异,研究幸运稀疏子网络和初始化(“彩票”)作为破坏性神经架构搜索技术的作用。

在本教程中,您将学习如何使用torch.nn.utils.prune稀疏神经网络,以及如何扩展它以实现您自己的自定义剪枝技术。

要求
“torch>=1.4.0a0+8e8a5e0”

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

创建模型
在本教程中,我们使用LeCun 等人 (1998) 的LeNet架构。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检查模块
让我们检查一下conv1LeNet 模型中的(未剪枝的)层。目前它将包含两个参数weight和bias,并且没有缓冲区。

module = model.conv1
print(list(module.named_parameters()))

输出:

[('weight', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
          [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
          [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
          [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
          [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0667,  0.1925, -0.1651, -0.1984]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.1033, -0.1363,  0.1061, -0.0808,  0.1214],
          [-0.0475,  0.1144, -0.1554, -0.1009,  0.0610],
          [ 0.0423, -0.0510,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0675, -0.0709, -0.1935]]],


        [[[-0.1145,  0.0500, -0.0264, -0.1452,  0.0047],
          [-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
          [ 0.1999,  0.0378,  0.0616, -0.1865, -0.1314],
          [-0.0666,  0.0313, -0.1760, -0.0862, -0.1197],
          [ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],


        [[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],
          [ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418,  0.1033,  0.1615,  0.1822, -0.1586]]]], device='cuda:0',
       requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))

输出:

[]

剪枝模块
要剪枝模块(在本例中是conv1LeNet 架构的层),首先在可用的剪枝技术中选择一种剪枝技术 torch.nn.utils.prune(或 通过子类化实现 您自己的 技术BasePruningMethod)。然后,指定模块以及要在该模块中删除的参数名称。最后,使用所选剪枝技术所需的适当关键字参数,指定剪枝参数。

weight在此示例中,我们将随机剪枝层中指定参数中 30% 的连接conv1。模块作为第一个参数传递给函数;name 使用其字符串标识符来标识该模块内的参数;并 amount指示要剪枝的连接的百分比(如果它是 0 和 1 之间的浮点数),或者要剪枝的连接的绝对数量(如果它是非负整数)。

prune.random_unstructured(module, name="weight", amount=0.3)

输出:

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

剪枝通过weight从参数中删除并将其替换为名为 的新参数weight_orig(即附加"_orig"到初始参数name)来进行。weight_orig存储张量的未剪枝版本。没有被剪枝,所以它bias会保持完整。

print(list(module.named_parameters()))

输出:

[('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497,  0.1822, -0.1468], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
          [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
          [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
          [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
          [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0667,  0.1925, -0.1651, -0.1984]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.1033, -0.1363,  0.1061, -0.0808,  0.1214],
          [-0.0475,  0.1144, -0.1554, -0.1009,  0.0610],
          [ 0.0423, -0.0510,  0.1192,  0.1360, -0.1450],
          [-0.1068,  0.1831, -0.0675, -0.0709, -0.1935]]],


        [[[-0.1145,  0.0500, -0.0264, -0.1452,  0.0047],
          [-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
          [ 0.1999,  0.0378,  0.0616, -0.1865, -0.1314],
          [-0.0666,  0.0313, -0.1760, -0.0862, -0.1197],
          [ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],


        [[[-0.1167, -0.0685, -0.1579,  0.1677, -0.0397],
          [ 0.1721,  0.0623, -0.1694,  0.1384, -0.0550],
          [-0.0767, -0.1660, -0.1988,  0.0572, -0.0437],
          [ 0.0779, -0.1641,  0.1485, -0.1468, -0.0345],
          [ 0.0418
  • 19
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

码农呆呆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值