ICCV 2017——Learning Efficient Convolutional Networks Through Network Slimming(模型剪枝)

ICCV 2017——Learning Efficient Convolutional Networks Through Network Slimming(模型剪枝)

在这里插入图片描述

论文地址:https://arxiv.org/pdf/1708.06519.pdf
源码地址:https://github.com/liuzhuang13/slimming

1. 论文概述

本篇论文提出了一种在channel-level上的裁剪策略,通过计算网络中feature map的重要性并对其进行裁剪,实现模型的压缩和加速。

1.1 论文动机

  1. 在卷积操作中,并非所有的特征图都是有用的,因此我们希望裁剪掉没有作用的特征图。
  2. 通过某种策略体现特征图权重参数的重要程度。

1.2 三种level的裁剪策略对比

  1. weigt-level:比较常见的是对权重进行置0操作,用于非结构化裁剪中,具备较大的压缩率,但在工业应用中通常需要特定的硬件和库来进行稀疏运算
  2. layer-level:对完整的卷积层进行裁剪,该方式灵活性差。且一般建立在网络足够深的模型中。
  3. channel-level:比较灵活,一般适用于各种CNN架构中。

2. 剪枝策略

本篇论文的主要思想是通过计算Batch Normalization层中的γ参数得到每个channel的重要程度,再通过L1正则化操作对γ参数进行稀疏,最后对原始网络中的每个channel的重要程度进行排序,并人为设置一个阈值(剪枝程度),裁剪掉低于该阈值对应的所有feature map。

接下来我们解释以下BN的作用(γ参数是什么,为什么这么做),L1范式的稀疏计算,以及总体网络流程。

2.1 BN操作的本质作用

首先我们先明确一个常规套路:Conv+ BN + Activation Function,我总结了BN的三个作用:

  • 当初期训练时,权重一般是随机初始化的,因此这时候的权重并没有很强的学习意义,我们需要通过BN策略将其拉到一个正确的学习方向上,也就是将偏离的分布拉回来。
  • 在训练过程中,将数据分布规范至标准的正态分布。
  • BN是介于卷积和激活函数之间的一个过度,我们希望将数据有约束的分布在非线性函数(激活函数)的线性区域,能够有效的时激活函数更加敏感,训练更快。(如下图所示)
    在这里插入图片描述
    通过上述的总结,我们可以引出BN操作的公式,即对输入数据进行一个标准化操作,再引入两个仿射变换参数进行分布的拉伸和偏移。值得注意的时这里的仿射变换参数是可学习的,而其中的γ参数就是本文的核心。为什么γ参数可以作为每个feature map重要性的表示呢?我们可以举一个简单的例子,模型其实希望输入激活函数的值更偏向线性区域,使其梯度更大,训练更快,因此需要进行一定程度的仿射变换,而当γ参数较大时,说明这张feature map中的权重是我们非常想要的,需要较大程度的变换,保证其能够更好的输入激活函数中。总结来说,γ参数越大,对应的feature map越重要。(注:每个channel对应一个γ参数,可以对照下面公式,γ参数的获取是基于channel参数的均值和方差获得的)
    z ^ = z in  − μ B σ B 2 + ϵ ; z out  = γ z ^ + β \hat{z}=\frac{z_{\text {in }}-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2+\epsilon}} ; \quad z_{\text {out }}=\gamma \hat{z}+\beta z^=σB2+ϵ zin μB;zout =γz^+β

2.2 L1正则化稀疏化

在这里插入图片描述
Figure 4的第一张图展示了在不做任何处理情况下γ参数(scaling factors)的分布情况,可见在对BN层不做任何约束的情况下,γ参数呈现分散分布。这样并不能很好的体现出feature map的重要程度,导致剪枝效果不佳,因此作者采用了L1正则化的方式来对γ参数分布进行稀疏化操作。具体如以下公式所示,其中λ参数为稀疏惩罚因子。当λ=0.0001时稀疏效果最佳,大部分参数趋向于0,说明该特征图重要程度低可以进行裁剪。
L = ∑ ( x , y ) l ( f ( x , W ) , y ) + λ ∑ γ ∈ Γ g ( γ ) L=\sum_{(x, y)} l(f(x, W), y)+\lambda \sum_{\gamma \in \Gamma} g(\gamma) L=(x,y)l(f(x,W),y)+λγΓg(γ)

2.3 剪枝流程

在这里插入图片描述
该图展示了本文的剪枝流程,核心围绕着如何计算出channel scaling factors进行。而计算 channel scaling factors的方案就是通过上述BN+L1正则化操作实现的。最后,我们对channel scaling factors由大到小进行排序,并保留人为设置阈值内的factors所对应的channel。
在这里插入图片描述
当我们对channel进行裁剪后,模型被压缩了,但是如果要想得到好的性能需要进行再训练操作进行网络参数的微调。因此本文整体操作流程为:训练,剪枝,再训练。注意的是这里的训练是指训练原始网络,再训练是训练剪枝后的网络。

2.4 实验

在这里插入图片描述
通过表格可以看到,该方案对VGG网络的紧致化计算效果很好,不仅压缩了模型还对网络进行了加速。且在DenseNet和ResNet上也有着一定的优化,但是文章对于这两个网络的实验主要建立在CIFAR数据集上,由于该数据集较简单,因此并不能很好的展示优化效果。文章基于mnist和ImageNet数据集对VGG网络进行了测试,效果还是不错的,但是性能的优化主要都是针对VGG网络的。我认为该方案对于残差结构还是有一定的限制,因为我们无法保障相加的两个特征之间在剪枝后的channel相等。

3. 代码实现

基于CIFAR 10跑了一下VGG网络(epoch=10),并获得以下结果。

状态准确率模型大小
训练87%152M
剪枝10%6M
再训练89%12M

每层的剪枝情况:
在这里插入图片描述
代码(main.py;prune.py;vgg.py)

# main.py

from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from vgg import vgg
import shutil

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR training')
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='training dataset (default: cifar10)')
parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
                    help='train with channel sparsity regularization')
parser.add_argument('--s', type=float, default=0.0001,
                    help='scale sparse rate (default: 0.0001)')
parser.add_argument('--refine', default='', type=str, metavar='PATH',
                    help='refine from prune model')
parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                    help='input batch size for training (default: 100)')
parser.add_argument('--test-batch-size', type=int, default=100, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 160)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Pad(4),
                       transforms.RandomCrop(32),
                       transforms.RandomHorizontalFlip(),
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

if args.refine:
    checkpoint = torch.load(args.refine)
    model = vgg(cfg=checkpoint['cfg'])
    model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
else:
    model = vgg()
if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.resume, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

# 正则化处理将其数据稀疏化
def updateBN():
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s*torch.sign(m.weight.data))  # L1


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        if args.sr:
            updateBN()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return 100. *correct / float(len(test_loader.dataset))


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

if __name__ == "__main__":
    best_prec1 = 0.
    for epoch in range(args.start_epoch, args.epochs):
        if epoch in [args.epochs*0.5, args.epochs*0.75]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        train(epoch)
        prec1 = test()
        print("--------------------------------------",prec1)
        is_best = prec1 > best_prec1
        print("--------------------------------------",is_best)
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)

# prune.py

import os
import argparse
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

from vgg import vgg
import numpy as np

# 剪枝设置
parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
# 数据集
parser.add_argument('--dataset', type=str, default='cifar10',
                    help='training dataset (default: cifar10)')
# 测试batchsize
parser.add_argument('--test-batch-size', type=int, default=10, metavar='N',
                    help='input batch size for testing (default: 1000)')
# cuda
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
# 剪枝率
parser.add_argument('--percent', type=float, default=0.5,
                    help='scale sparse rate (default: 0.5)')
# 剪枝模型
parser.add_argument('--model', default='', type=str, metavar='PATH',
                    help='path to raw trained model (default: none)')
# 保存模型
parser.add_argument('--save', default='', type=str, metavar='PATH',
                    help='path to save prune model (default: none)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

model = vgg()

if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
              .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

print(model)
total = 0
# model.modules()用于迭代遍历没一个网络层
for m in model.modules():
    # 判断该层若为BN层则返回True,主要计算特征图总数
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
        # print("BN_num: ", total)

# 通过排列剪枝因子大小进行排序计算剪枝率
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size

y, i = torch.sort(bn)
thre_index = int(total * args.percent)
thre = y[thre_index]

# 求mask和需要保留的特征图
# 这一步还是做非结构化的剪枝,不改变网络结构,只是将权重置为0
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        # 大于阈值置为1,小于阈值置为0
        mask = weight_copy.gt(thre.cuda()).float().cuda()
        # 计算要剪枝的特征图总数
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        # 将之前训练的模型每层的数据和mask相乘,激活有用的,抑制没用的
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        # 计算每层有用的特征图数量
        cfg.append(int(torch.sum(mask)))
        # 计算特征每层的mask
        cfg_mask.append(mask.clone())

        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')

pruned_ratio = pruned/total

print("pruned_ratio: ", pruned_ratio)
print('Pre-processing Successful!')


# simple test model after Pre-processing prune (simple set BN scales to zeros)
def test():
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    # 加载测试集
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)
    # 关闭训练模式,打开评估模式。这时候会将网络中的优化操作如Dropout和BN关闭,使模型不会造成偏移
    model.eval()
    correct = 0
    for data, target in test_loader:
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))

if __name__ == "__main__":
    # 剪枝评估,一般效果并不好。所以需要进行再训练
    test()
    
    # Make real prune
    print(cfg)
    newmodel = vgg(cfg=cfg)
    newmodel.cuda()
    
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    # zip是聚合迭代器,可以访问两个迭代器中的内容
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            # 把shape中为1的维度去掉
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d} Out shape:{:d}'.format(idx0.shape[0], idx1.shape[0]))
            w = m0.weight.data[:, idx0, :, :].clone()
            w = w[idx1, :, :, :].clone()
            m1.weight.data = w.clone()
            # m1.bias.data = m0.bias.data[idx1].clone()
        elif isinstance(m0, nn.Linear):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            m1.weight.data = m0.weight.data[:, idx0].clone()
    
    
    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, args.save)
    
    print(newmodel)
    model = newmodel
    test()
# vgg.py

import torch
import torch.nn as nn
from torch.autograd import Variable
import math  # init


class vgg(nn.Module):

    def __init__(self, dataset='cifar10', init_weights=True, cfg=None):
        super(vgg, self).__init__()
        if cfg is None:
            cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512]
        self.feature = self.make_layers(cfg, True)

        if dataset == 'cifar100':
            num_classes = 100
        elif dataset == 'cifar10':
            num_classes = 10
        self.classifier = nn.Linear(cfg[-1], num_classes)
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.feature(x)
        x = nn.AvgPool2d(2)(x)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


if __name__ == '__main__':
    net = vgg()
    x = Variable(torch.FloatTensor(16, 3, 40, 40))
    y = net(x)
    print(y.data.shape)

运行流程:

Trained with Sparsity:python main.py -sr --s 0.0001

Pruned:python prune.py --model model_best.pth.tar --save pruned.pth.tar --percent 0.7

Fine-tuned:python main.py -refine pruned.pth.tar --epochs 10

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值