1. Imports & defining a simple network
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# By PyTanAI. 2023.05.07.
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''Build a LeNet-like network'''
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # single-channel image input, 5x5 kernels
        self.conv1 = nn.Conv2d(1, 3, 5)
        self.conv2 = nn.Conv2d(3, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
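As a quick sanity check (a small sketch not in the original post; it assumes the classic 32×32 single-channel LeNet input, which is what makes the flattened feature size 16 * 5 * 5), we can push a dummy tensor through the model:
sanity_model = LeNet().to(device=device)
dummy = torch.randn(1, 1, 32, 32, device=device)  # dummy single-channel 32x32 image
print(sanity_model(dummy).shape)  # expected: torch.Size([1, 10])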
2. Getting the module of the network to prune
model = LeNet().to(device=device)
module = model.conv1
print(list(module.named_parameters()))  # parameters: a 3×1×5×5 weight plus a 3-element bias
Output:
[('weight', Parameter containing:
tensor([[[[-0.1886, 0.1689, 0.0006, 0.1656, 0.0707],
[ 0.1573, -0.0014, -0.1035, 0.0448, 0.1883],
[ 0.1790, -0.0220, 0.0037, 0.1470, 0.1930],
[-0.0082, -0.0086, -0.0164, -0.1476, -0.1250],
[-0.1283, 0.1223, -0.1508, 0.0952, -0.0481]]],
[[[-0.0183, 0.0274, 0.1387, -0.0060, 0.1295],
[-0.1517, -0.1871, 0.0382, 0.1095, 0.1002],
[-0.0717, -0.1329, 0.0020, 0.1161, 0.1372],
[ 0.0108, -0.1599, -0.0527, 0.1284, 0.0222],
[ 0.1408, 0.1704, 0.0745, -0.1910, 0.1014]]],
[[[ 0.0449, 0.1346, -0.1890, -0.0563, 0.0462],
[-0.0788, 0.0057, -0.0536, -0.0356, -0.0771],
[ 0.1984, 0.1494, 0.1578, 0.0566, -0.1480],
[ 0.0585, 0.0059, 0.0954, 0.1955, 0.0770],
[-0.1679, -0.0318, 0.1034, -0.0277, -0.0480]]]], device='cuda:0',
requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1860, -0.0283, -0.1901], device='cuda:0', requires_grad=True))]
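Before any pruning is applied the module carries no buffers, which we can confirm with a one-line check (a small addition mirroring the inspections used later in this post):
print(list(module.named_buffers()))  # -> [] : no masks exist yet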
3. Structured pruning of the module (the core)
Pruning a module takes three steps:
- Step 1: pick a pruning method from torch.nn.utils.prune, or define a custom one by subclassing BasePruningMethod (see the sketch after this list)
- Step 2: specify the module and the name of the parameter to prune within it
- Step 3: supply the arguments the chosen pruning function requires
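For the custom option in step 1, the pattern looks like this (a minimal sketch following the official tutorial's custom-pruning example; the class and function names here are made up for illustration). It prunes every other entry of a tensor:
class EveryOtherPruningMethod(prune.BasePruningMethod):
    '''Illustrative custom method: prune every other entry of a tensor.'''
    PRUNING_TYPE = 'unstructured'  # the mask is computed entry-wise

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0  # zero out every other connection
        return mask

def every_other_unstructured(module, name):
    '''Apply the custom method to the parameter `name` of `module`.'''
    EveryOtherPruningMethod.apply(module, name)
    return module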
Here we structurally prune the weight parameter along axis 0 of the tensor (axis 0 corresponds to the convolution's output channels; conv1's weight has shape 3×1×5×5), ranking channels by their L2 norm, using the ln_structured() method with amount=0.33 (which rounds to one of the three channels), dim=0, and n=2 for the L2 norm.
prune.ln_structured(module, name="weight", amount=0.33, n=2, dim=0)
print(module.weight)
Output:
tensor([[[[-0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, -0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, -0.0000, 0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000]]],
[[[-0.1167, 0.1299, -0.1941, -0.0426, -0.1658],
[ 0.1906, -0.0087, 0.0125, -0.0981, -0.0469],
[ 0.1716, -0.1744, 0.1467, -0.0231, 0.0885],
[-0.0793, -0.0523, -0.0424, -0.1222, 0.0029],
[-0.0573, -0.1058, -0.0319, -0.1033, 0.1366]]],
[[[-0.1145, -0.1666, -0.0811, 0.0916, -0.0191],
[-0.1735, 0.0730, 0.1786, -0.1338, -0.0941],
[ 0.0123, -0.0719, 0.1999, 0.1462, -0.0013],
[ 0.0112, -0.0411, 0.1806, -0.0925, 0.0962],
[-0.0791, 0.1241, 0.1079, 0.0158, -0.1588]]]], device='cuda:0',
grad_fn=<MulBackward0>)
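To see why exactly that channel was zeroed (a quick check, not part of the original tutorial): after pruning, module.weight_orig holds the unpruned values and module.weight_mask the 0/1 mask, so we can compare per-channel L2 norms against the mask:
channel_norms = module.weight_orig.detach().norm(p=2, dim=(1, 2, 3))
print(channel_norms)                   # one L2 norm per output channel
print(channel_norms.argmin())          # the channel ln_structured selects for pruning
print(module.weight_mask[:, 0, 0, 0])  # 0. for the pruned channel, 1. elsewhere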
All the relevant tensors, including the mask buffers and the original parameter used to compute the pruned tensor, are stored in the model's state_dict, so they can easily be serialized and saved if needed.
print(model.state_dict().keys())
Output:
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
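Under the hood, pruning replaces weight with the reparametrization weight = weight_orig * weight_mask, recomputed by a forward_pre_hook on every forward pass. A small sketch to confirm this and to round-trip the pruned state through a checkpoint (the file name is hypothetical):
# the exposed weight is the elementwise product of original values and mask
assert torch.allclose(module.weight, module.weight_orig * module.weight_mask)

torch.save(model.state_dict(), "pruned_lenet.pt")  # hypothetical path

# A fresh model lacks the weight_orig/weight_mask keys, so attach a no-op
# pruning reparametrization first; load_state_dict then overwrites it.
model2 = LeNet().to(device=device)
prune.identity(model2.conv1, "weight")
model2.load_state_dict(torch.load("pruned_lenet.pt", map_location=device))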
To make the pruning permanent, we can remove the reparametrization in terms of weight_orig and weight_mask, along with the forward_pre_hook, using the remove function from torch.nn.utils.prune. Note that this does not undo the pruning as if it had never happened; it simply makes it permanent by reassigning the parameter weight, in its pruned version, to the model's parameters.
After calling remove:
prune.remove(module, 'weight')
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(model.state_dict().keys())
Output:
[('bias', Parameter containing:
tensor([ 0.1144, -0.1641, 0.0962], requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0142, 0.1698, -0.0730, -0.0358, -0.1309],
[ 0.1520, 0.1900, -0.0843, 0.0950, 0.1674],
[-0.1724, 0.1453, -0.1764, 0.0345, -0.1767],
[ 0.0727, 0.1170, 0.1585, -0.0713, -0.0158],
[ 0.1485, -0.0270, -0.0164, 0.0889, 0.1170]]],
[[[ 0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[-0.1681, 0.1801, -0.0567, -0.0366, 0.0085],
[ 0.0495, 0.0320, -0.0127, -0.1761, -0.0948],
[ 0.1340, 0.1103, 0.1332, -0.1911, -0.1225],
[ 0.0781, -0.0920, -0.1759, 0.0977, 0.0030],
[-0.0436, -0.1694, -0.0094, -0.0553, -0.0591]]]], requires_grad=True))]
[]
odict_keys(['conv1.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
Notice that the mask previously stored in the buffers is gone: named_buffers() is now empty, and the state_dict once again contains a plain conv1.weight instead of conv1.weight_orig and conv1.weight_mask.
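Even though the reparametrization is gone, the zeros remain baked into the weight itself. A short sketch (not in the original post) to measure the remaining sparsity:
# fraction of conv1 weights that are exactly zero after pruning
sparsity = 100.0 * float(torch.sum(module.weight == 0)) / module.weight.nelement()
print(f"Sparsity in conv1.weight: {sparsity:.2f}%")  # ~33% here: 1 of 3 channels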
4. Summary
This example first built a LeNet-like network. For structured pruning we selected its conv1 module, whose parameters consist of a 3×1×5×5 convolution weight and a 3-element bias. Using the ln_structured pruning method from torch.nn.utils.prune, we applied L2-norm structured pruning to weight, zeroing out one of its three output channels.
Core functions and methods used in this post:
- model.state_dict().keys(): lists all of the model's tensors, including the mask buffers while pruning is active
- prune.ln_structured(): Ln-norm structured pruning
- prune.remove(): removes the pruning reparametrization from a module. The pruned parameter named name stays permanently pruned, the parameter named name+'_orig' is removed from the parameter list, and the buffer named name+'_mask' is removed from the buffers.
References: the official PyTorch pruning tutorial