pytorch框架下实现top-k剪枝
这篇博客,以MNIST数据集为例,对LSTM的权重矩阵实现top-k剪枝(7,2),介绍了如何在pytorch框架下实现top-k剪枝。
一. top-k剪枝
-
LSTM常被应用自然语言处理(NLP)相关的应用,由于引入了memory cell和gate unit,其含有大量参数,即使被剪枝90%的参数,仍然不会给模型带来太大的精度损失,较多的冗余参数带来很多不必要的资源消耗,因此需要被剪枝。随机剪枝产生的稀疏矩阵,需要额外的资源去存储位置信息,因此,规则剪枝更占优势。
-
这篇博客采用MNIST数据集,搭建了一个含有双层LSTM,线性层的RNN模型,其中LSTM的输入,隐藏层输出维度均为28,采用的top-k为,lstm的权重矩阵的每一行,7个分为一组,每组只保留最大的2个,其余的均为0。top-k剪枝的文献
-
这样剪枝获得的权重矩阵每一行数量都相等,且保留下来的权重的位置信息,只需要3个2进制数就可以表示,符合FPGA运算时对负载平衡和减少参数的需求。
二. 生成掩模(mask)矩阵
- Pytorch剪枝时,需要一个掩模矩阵,该矩阵和待剪枝的矩阵维度大小相等,只包含1,0两个数值,1表示该位置的数据保留,0表示该位置的数据被剪枝;
可以使用如下代码,查看模型都含有哪些权重矩阵:
for name, _ in model.named_parameters():
print(name)
- 我定义的rnn模型,lstm(双层)含有的权重参数为rnn.lstm.weight_ih_l0,rnn.lstm.weight_hh_l0, rnn.lstm.weight_ih_l1, rnn.lstm.weight_hh_l1.
矩阵每行含有28个参数,将其分为4组,每组7个元素,只保留最大的2个:
def topk(para, k):
c = torch.zeros(para.size()[0], para.size()[1],dtype = torch.int) #初始化一个和权值矩阵相同大小的掩膜矩阵
l = int(para.size()[1]/7) #将每行的每7个权值分为一组,l为分组的数量
parameter = torch.abs(para) #将权值矩阵取绝对值
_, b = torch.topk(parameter[:,:7], k, 1, largest = True) #b为0~6之间的k个数,表示该组最大的前k个权值的位置
for i in range(1,l):
_, b1 = torch.topk(parameter[:,i*7:(i+1)*7], k, 1, largest = True) #遍历每一组最大的前k个值的位置
b1 = b1 + i * 7 #得到每一行中保留的权值位置信息的绝对位值
b = torch.cat((b,b1),dim=1) #将每一段拼接起来
for j in range(c.size()[0]):
c[j, b[j, :]] = 1 #将c中,b中位置信息的对应的位置,置1(保留),其他部分为0
return c
c1,c2,c3,c4是根据四个权重矩阵生成的四个掩模矩阵(我定义的双层LSTM有四个权重矩阵),生成的掩模矩阵元素均为0或1
c1 = topk(rnn.lstm.weight_ih_l0.data, 2)
c2 = topk(rnn.lstm.weight_hh_l0.data, 2)
c3 = topk(rnn.lstm.weight_ih_l1.data, 2)
c4 = topk(rnn.lstm.weight_hh_l1.data, 2)
生成的掩模矩阵如图所示:
三. 定义剪枝函数
pytorch提供的自定义剪枝的模板,这里分别将c1,c2,c3,c4作为掩模矩阵,这段代码的意思就是,rnn模型中的lstm层的权重矩阵weight_ih_l0对应掩模矩阵c1, c1元素为1的位置,保留;c1为0的,weight_ih_l0对应的位置被剪枝掉,以此类推;
class FooBarPruningMethod1(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = c1
return mask
class FooBarPruningMethod2(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = c2
return mask
class FooBarPruningMethod3(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = c3
return mask
class FooBarPruningMethod4(prune.BasePruningMethod):
"""Prune every other entry in a tensor
"""
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = c4
return mask
def foobar_unstructured(model):
FooBarPruningMethod1.apply(model.lstm, 'weight_ih_l0')
FooBarPruningMethod2.apply(model.lstm, 'weight_hh_l0')
FooBarPruningMethod3.apply(model.lstm, 'weight_ih_l1')
FooBarPruningMethod3.apply(model.lstm, 'weight_hh_l1')
return model
rnn = foobar_unstructured(rnn) #对预训练完成的模型进行top-k剪枝
剪枝过后再训练,会发现,剪枝后的训练速度,明显快于剪枝前。
剪枝后的矩阵如图所示:
总结
这篇博客以MNIST数据集为例,搭建了一个含有双层LSTM,和FC层的模型,预训练后对其进行top-k剪枝,详细介绍了pytorch框架下的top-k剪枝过程;
- 完整代码下载:pytorch-topk
参考文献
- top-k剪枝的文献:E-LSTM: An Efficient Hardware Architecture for Long Short-Term Memory
- pytorch官方剪枝教程:pytorch剪枝