VGG Model Pruning Example

Building the VGG Model

VGG is a classic image-classification architecture.

Create a new file named vgg.py and copy the following code into it. It defines the VGG model.

import torch
import torch.nn as nn

# official pretrained weights (kept for reference; not used below)
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}
class VGG(nn.Module):
    def __init__(self, dataset='cifar10', num_classes=10, init_weights=True, cfg=None):
        super(VGG, self).__init__()
        # default cfg is VGG-19 for CIFAR; the final max-pool is replaced by
        # the average pool applied in forward()
        if cfg is None:
            cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
        print(cfg)
        self.feature = self.make_layers(cfg, True)
        if dataset == 'cifar100':
            num_classes = 100
        elif dataset == 'cifar10':
            num_classes = 10
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:  # whether to initialize the weights
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    # forward pass
    def forward(self, x):
        # input: N x 3 x 32 x 32 for CIFAR
        x = self.feature(x)  # feature-extraction backbone
        x = nn.AvgPool2d(2)(x)  # pools the remaining 2x2 feature map to 1x1
        x = x.view(x.size(0), -1)
        x = self.classifier(x)  # classification head produces the logits
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)



cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


# instantiate and configure the model
def vgg(model_name="vgg19", cfg=None, **kwargs):
    assert model_name in cfgs, "Warning: model name {} not in cfgs dict!".format(model_name)
    model = VGG(cfg=cfg, **kwargs)  # remaining keyword arguments are forwarded to VGG
    return model
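Before training, it is worth a quick sanity check that the model builds and produces the expected output shape. A minimal sketch (run separately, e.g. in a REPL; not part of vgg.py):

import torch
from vgg import vgg

model = vgg()                   # default VGG-19 configuration
model.eval()
x = torch.randn(1, 3, 32, 32)   # one dummy CIFAR-sized image
print(model(x).shape)           # expected: torch.Size([1, 10])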

Training the VGG Model

Create a new file named main.py and copy the following code into it. Running it automatically downloads the CIFAR-10 dataset and trains the VGG-19 model, writing the weight files checkpoint.pth.tar and model_best.pth.tar.

from __future__ import print_function
import os
import argparse
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from vgg import vgg

# Workflow:
# 1. Train with channel sparsity regularization:  -sr --s 0.0001
# 2. Prune:     --model model_best.pth.tar --save pruned.pth.tar --percent 0.7
# 3. Fine-tune: --refine pruned.pth.tar --epochs 40

parser = argparse.ArgumentParser(description='PyTorch slimming CIFAR training')
parser.add_argument('--dataset', type=str, default='cifar10', help='training dataset')
parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
                    help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.0001, help='sparsity regularization strength')
parser.add_argument('--refine', type=str, default='', metavar='PATH',
                    help='path to a pruned model to fine-tune')
parser.add_argument('--batch_size', type=int, default=100, metavar='N', help='training batch size')
parser.add_argument('--test_batch_size', type=int, default=100, metavar='N', help='test batch size')
parser.add_argument('--epochs', type=int, default=5, metavar='N', help='number of epochs to train')
parser.add_argument('--start_epoch', type=int, default=0, metavar='N', help='starting epoch number')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
parser.add_argument('--weight_decay', '--wd', type=float, default=1e-4, metavar='W', help='weight decay')
parser.add_argument('--resume', type=str, default='', metavar='PATH', help='path to a checkpoint to resume from')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disable CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='batches between log messages')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, download=True, transform=transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    batch_size=args.batch_size, shuffle=True, **kwargs
)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs
)

if args.refine:
    # resume from a pruned model: rebuild the slim architecture from its cfg
    checkpoint = torch.load(args.refine)
    model = vgg(cfg=checkpoint['cfg'])
    model.load_state_dict(checkpoint['state_dict'])
else:
    model = vgg()
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
if args.resume:
    if os.path.isfile(args.resume):
        print("loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("loaded checkpoint '{}' (epoch {}) prec1: {:f}".format(args.resume, checkpoint['epoch'], best_prec1))
    else:
        print("no checkpoint found at '{}'".format(args.resume))

def updateBN():
    # network-slimming sparsity penalty: add the subgradient of
    # s * sum(|gamma|) to the gradient of every BN scale factor
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s * torch.sign(m.weight.data))

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        if args.sr:
            updateBN()  # applied after backward(), before the optimizer step
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('train epoch: {} [{}/{} ({:.1f}%)]\tloss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: average loss: {:.4f}, accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

best_prec1 = 0
for epoch in range(args.start_epoch, args.epochs):
    # decay the learning rate by 10x at 50% and 75% of training
    if epoch in [int(args.epochs * 0.5), int(args.epochs * 0.75)]:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= 0.1
    train(epoch)
    prec1 = test()
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
        'optimizer': optimizer.state_dict(),
    }, is_best)

print("done")
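The heart of the sparsity training above is updateBN: adding s * sign(gamma) to each BN scale factor's gradient is exactly the subgradient of an L1 penalty s * sum(|gamma|), which pushes unimportant channels' scale factors toward zero so they can be pruned later. A standalone toy sketch (not part of main.py) checking that equivalence with autograd:

import torch
import torch.nn as nn

s = 0.0001
bn = nn.BatchNorm2d(4)
bn.weight.data.uniform_(-1, 1)                    # spread out the scale factors (gamma)
loss = bn(torch.randn(2, 4, 8, 8)).pow(2).mean()  # stand-in for the task loss
loss.backward()

grad_before = bn.weight.grad.clone()
bn.weight.grad.data.add_(s * torch.sign(bn.weight.data))  # the updateBN step

# gradient of the explicit penalty s * sum(|gamma|), computed by autograd
gamma = bn.weight.detach().clone().requires_grad_(True)
(s * gamma.abs().sum()).backward()

print(torch.allclose(bn.weight.grad - grad_before, gamma.grad))  # True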


Pruning the VGG Model

Create a new file named prune.py and copy the following code into it.

Adjust the --percent argument to control the fraction of channels to prune; for example, --percent 0.7 removes the 70% of channels with the smallest BN scale factors, as illustrated in the sketch below.
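To make the criterion concrete: the script gathers the absolute BN scale factors (gamma) of every channel in the network, sorts them, and takes the value at the percent quantile as a single global threshold. A toy illustration with made-up numbers:

import torch

# |gamma| values pooled from all BN layers (hypothetical numbers)
gammas = torch.tensor([0.9, 0.02, 0.5, 0.01, 0.7, 0.03])
percent = 0.5

y, _ = torch.sort(gammas)
thre = y[int(len(gammas) * percent)]  # global threshold: here y[3] = 0.5
mask = gammas.gt(thre)                # keeps only the 0.9 and 0.7 channels
print(thre.item(), mask)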

Running the file prints the pruned model's cfg and saves the pruned weight file.

import os
import argparse

import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets, transforms

from vgg import vgg
parser = argparse.ArgumentParser(description='PyTorch slimming CIFAR pruning')
parser.add_argument('--dataset', type=str, default='cifar10', help='training dataset')
parser.add_argument('--test_batch_size', type=int, default=100, metavar='N', help='test batch size')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disable CUDA')
parser.add_argument('--model', type=str, default='model_best.pth.tar', metavar='PATH',
                    help='path to the sparsity-trained model')
parser.add_argument('--save', type=str, default='pruned.pth.tar', metavar='PATH',
                    help='path to save the pruned model')
parser.add_argument('--percent', type=float, default=0.7, help='fraction of channels to prune')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

model = vgg()
if args.cuda:
    model.cuda()
if args.model:
    if os.path.isfile(args.model):
        print("loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("loaded checkpoint '{}' (epoch {}) prec1: {:f}".format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("no checkpoint found at '{}'".format(args.model))

print(model)
# count the total number of BN channels across the network
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

# gather the absolute values of all BN scale factors into one vector
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()
        index += size

# global threshold: the percent quantile of all |gamma| values
y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]

# zero out channels whose scale factor falls below the threshold and
# record the surviving channel counts as the new cfg
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.clone()
        mask = weight_copy.abs().gt(thre).float()
        if args.cuda:
            mask = mask.cuda()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print("layer index: {:d} \t total channels: {:d} \t remaining channels: {:d}"
              .format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = float(pruned) / total
print('pre-processing successful! pruned ratio: {:.4f}'.format(pruned_ratio))

def test():
    kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs
    )
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    print('\nTest set: accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

test()



print(cfg)
newmodel = vgg(cfg=cfg)
if args.cuda:
    newmodel.cuda()
print(newmodel)

# copy the surviving weights from the old model into the slim one, layer by
# layer, using the input (start) and output (end) channel masks
layer_id_in_cfg = 0
start_mask = torch.ones(3)          # the network input has 3 (RGB) channels
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.argwhere(np.asarray(end_mask.cpu().numpy())).reshape(-1)
        m1.weight.data = m0.weight.data[idx1].clone()
        m1.bias.data = m0.bias.data[idx1].clone()
        m1.running_mean = m0.running_mean[idx1].clone()
        m1.running_var = m0.running_var[idx1].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()   # this layer's outputs are the next layer's inputs
        if layer_id_in_cfg < len(cfg_mask):
            end_mask = cfg_mask[layer_id_in_cfg]

    elif isinstance(m0, nn.Conv2d):
        idx0 = np.argwhere(np.asarray(start_mask.cpu().numpy())).reshape(-1)
        idx1 = np.argwhere(np.asarray(end_mask.cpu().numpy())).reshape(-1)
        print('in shape: {:d}, out shape: {:d}'.format(idx0.shape[0], idx1.shape[0]))
        w = m0.weight.data[:, idx0, :, :].clone()   # keep surviving input channels
        w = w[idx1, :, :, :].clone()                # keep surviving output channels
        m1.weight.data = w.clone()

    elif isinstance(m0, nn.Linear):
        idx0 = np.argwhere(np.asarray(start_mask.cpu().numpy())).reshape(-1)
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)
print(newmodel)
model = newmodel
test()
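To quantify the compression, one can compare parameter counts of the two networks. A small optional sketch (not in the original script) that could be appended at the end of prune.py:

# rebuild an unpruned VGG-19 purely to count its parameters against the pruned model
ref = vgg()
n_old = sum(p.numel() for p in ref.parameters())
n_new = sum(p.numel() for p in newmodel.parameters())
print('parameters: {} -> {} ({:.1f}% kept)'.format(n_old, n_new, 100. * n_new / n_old))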




Fine-Tuning the Pruned Model

Return to main.py and set the --refine argument to the path of the pruned weight file (for example, python main.py --refine pruned.pth.tar --epochs 40). Fine-tuning then recovers the accuracy lost to pruning and yields the final slim VGG model.
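Note that save_checkpoint in main.py stores only the weights, not the pruned cfg, so loading the fine-tuned slim model requires the cfg from pruned.pth.tar and the weights from model_best.pth.tar. A loading sketch, assuming the default file names:

import torch
from vgg import vgg

cfg = torch.load('pruned.pth.tar', map_location='cpu')['cfg']        # pruned architecture
checkpoint = torch.load('model_best.pth.tar', map_location='cpu')    # fine-tuned weights

model = vgg(cfg=cfg)
model.load_state_dict(checkpoint['state_dict'])
model.eval()  # ready for inference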


 
