最近在学习模型压缩中的剪枝
但是对于怎么实现剪枝不太了解
于是查找了别人的代码,并在过程中加入自己的注释和理解
这次学习的是对在 CIFAR-10 上训练好的 ResNet18 模型进行的剪枝
代码源于
https://github.com/kentaroy47/Deep-Compression.Pytorch
以下是prune模块
# -*- coding: utf-8 -*-
'''Deep Compression with PyTorch.'''
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models import *
from utils import progress_bar
import numpy as np
# ---- Command-line configuration ----
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Pruning')
parser.add_argument('--loadfile', '-l', default="checkpoint/res18.t7",dest='loadfile')
parser.add_argument('--prune', '-p', default=0.5, dest='prune', help='Parameters to be pruned')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--net', default='res18')
args = parser.parse_args()
prune = float(args.prune) # fraction of weights to remove, e.g. 0.5 prunes 50%
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
# ---- Data: standard CIFAR-10 augmentation for train, plain normalize for test ----
print('==> Preparing data..')
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=0)
# ---- Model ----
# NOTE(review): if --net is anything other than 'res18' or 'vgg', `net` is
# never bound and the script crashes at net.to(device) below.
print('==> Building model..')
if args.net=='res18':
net = ResNet18()
elif args.net=='vgg':
net = VGG('VGG19')
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
cudnn.benchmark = True
# Load pretrained weights from checkpoint.
# NOTE(review): `assert` is stripped under `python -O`; an explicit raise
# would be safer for input validation.
print('==> Resuming from checkpoint..')
assert os.path.isfile(args.loadfile), 'Error: no checkpoint directory found!'
checkpoint = torch.load(args.loadfile) # dict, keys observed: 'acc', 'epoch', 'net', 'address', 'mask'
net.load_state_dict(checkpoint['net'])
# dict_keys(['acc', 'epoch', 'net', 'address', 'mask']), len(checkpoint) = 5
# NOTE(review): debug print — dumps every checkpoint value (incl. tensors), very noisy.
print(checkpoint.values())
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
def prune_weights(torchweights, ratio=None):
    """Magnitude-prune a weight tensor.

    Zeroes out the ``ratio`` fraction of entries with the smallest
    absolute value and returns the pruned tensor plus a binary mask.

    Args:
        torchweights: torch weight tensor to prune (any shape).
        ratio: fraction of weights to remove, in [0, 1]. Defaults to the
            module-level ``prune`` setting from the command line, which
            keeps the original call sites working unchanged.

    Returns:
        (pruned, masks): ``pruned`` is a torch tensor on the same device
        as the input, with the smallest-magnitude entries zeroed and the
        original signs of the survivors preserved; ``masks`` is a numpy
        array of the same shape (1 = kept, 0 = pruned).

    Fixes two bugs in the original implementation:
      * it compared the flat *indices* returned by argsort against
        ``prune_num`` instead of using the ranks, so which weights were
        masked was unrelated to their magnitude;
      * it returned ``masks * np.abs(weights)``, silently dropping the
        sign of every surviving weight when written back into the model.
    It also replaces the hard-coded ``.cuda()`` (which crashed on
    CPU-only machines) with the input tensor's own device.
    """
    if ratio is None:
        ratio = prune  # module-level CLI setting
    signed = torchweights.cpu().numpy()
    magnitudes = np.abs(signed)
    weightshape = magnitudes.shape
    num = magnitudes.size
    prune_num = int(np.round(num * ratio))
    print('prune_num:', prune_num)
    # Flat indices sorted by ascending magnitude: the first prune_num of
    # them point at the weights to drop, the rest are kept.
    order = magnitudes.reshape(num).argsort()
    masks = np.zeros(num, dtype=signed.dtype)
    masks[order[prune_num:]] = 1  # keep the largest (1 - ratio) fraction
    masks = masks.reshape(weightshape)
    print("total weights:", num)
    print("weights pruned:", prune_num)
    # Multiply by the *signed* weights so survivors keep their sign.
    pruned = masks * signed
    return torch.from_numpy(pruned).to(torchweights.device), masks
'''for example
pruning layer: module.layer1.0.conv2.weight
prune_num: 18432
n, rankedweight: 0 14054
n, rankedweight: 1 1747
n, rankedweight: 2 31774
n, rankedweight: 3 16140
n, rankedweight: 4 1811
n, rankedweight: 5 35556
n, rankedweight: 6 16134
n, rankedweight: 7 1784
n, rankedweight: 8 7769
n, rankedweight: 9 1896
n, rankedweight: 10 16356
n, rankedweight: 11 2028
n, rankedweight: 12 1808
n, rankedweight: 13 30484
n, rankedweight: 14 30050
masks: [0 0 1 ... 1 0 1]
total weights: 36864
weights pruned: 18433
###############################
64
'''
# print("rankedweights:",rankedweights)
# prune weights
# The pruned weight location is saved in the addressbook and maskbook.
# These will be used during training to keep the weights zero.
addressbook=[]
maskbook=[]
# state_dict().items() yields (parameter name, tensor) pairs;
# only layers whose name contains "conv2" are pruned here.
for k, v in net.state_dict().items():
if "conv2" in k:
addressbook.append(k)
# k is a state-dict key like module.layer*.*.conv2.weight
print("pruning layer:",k)
# print('\t', v,v.size(1),'\t',v.size(2))
weights=v # the conv2 weight tensor for this layer
weights, masks = prune_weights(weights)
# print(len(masks)) #len = 64, 128, 256, 512
maskbook.append(masks)
# print(weights)
# Overwrite the checkpoint entry with the pruned weights.
checkpoint['net'][k] = weights
# Record which layers were pruned and their masks so training can
# keep those positions at zero; load the pruned weights into the model.
checkpoint['address'] = addressbook
checkpoint['mask'] = maskbook
net.load_state_dict(checkpoint['net'])
# Training
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=5e-4)
def train(epoch):
    """Train for one epoch, re-applying the pruning masks after every
    optimizer step so pruned weights stay at zero.

    Fixes the original masking logic, which (a) rebound entries of a
    state-dict *copy* (``checkpoint['net'][address] = ...``) to new CPU
    tensors and therefore never modified the live parameters, and
    (b) ran before ``optimizer.step()``, so the gradient update would
    have re-grown the pruned weights anyway. It also removes the
    per-batch debug prints that dumped whole weight tensors.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Re-zero pruned weights in place AFTER the step. state_dict()
        # values are references to the parameter tensors, so an in-place
        # multiply really modifies the network.
        with torch.no_grad():
            state = net.state_dict()
            for address, mask in zip(addressbook, maskbook):
                m = torch.as_tensor(mask, dtype=state[address].dtype,
                                    device=state[address].device)
                state[address].mul_(m)
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
def test(epoch):
    """Evaluate on the test set and checkpoint the best model.

    Improvement over the original: the saved state now includes the
    'address' and 'mask' books, matching the key set of the checkpoint
    loaded at startup ('net', 'acc', 'epoch', 'address', 'mask') so the
    pruning masks survive a reload and can be re-applied.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    # Save checkpoint when the accuracy improves.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'address': addressbook,  # names of the pruned layers
            'mask': maskbook,        # binary masks (1 = kept, 0 = pruned)
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/pruned-'+args.net+'-ckpt.t7')
        best_acc = acc
if __name__ == '__main__':
    # Fine-tune the pruned network for 20 epochs, appending the epoch
    # number and the running best accuracy to a results log each time.
    log_name = "prune-results-" + str(prune) + '-' + str(args.net) + ".txt"
    for epoch in range(start_epoch, start_epoch + 20):
        train(epoch)
        test(epoch)
        with open(log_name, "a") as log_file:
            log_file.write(str(epoch) + "\n")
            log_file.write(str(best_acc) + "\n")
自己还有很多不太懂的地方,记录一下学习经历,day day up