Deep Learning Model Pruning: Pruning Workflow and Results for Pcdet-PointPillars

1. Original Pcdet-PointPillars model structure

The network consists of four parts:
    (1)PillarVFE
    (2)PointPillarScatter
    (3)BaseBEVBackbone
    (4)AnchorHeadSingle

Pruning targets mainly BaseBEVBackbone. Its module structure is as follows:

  (backbone_2d): BaseBEVBackbone(
    (blocks): ModuleList(
      (0): Sequential(
        (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (2): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (8): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (9): ReLU()
        (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (11): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (12): ReLU()
      )
      (1): Sequential(
        (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (2): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (8): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (9): ReLU()
        (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (11): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (12): ReLU()
        (13): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (14): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (15): ReLU()
        (16): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (17): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (18): ReLU()
      )
      (2): Sequential(
        (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (2): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (3): ReLU()
        (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (6): ReLU()
        (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (8): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (9): ReLU()
        (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (11): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (12): ReLU()
        (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (14): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (15): ReLU()
        (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (17): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (18): ReLU()
      )
    )
    (deblocks): ModuleList(
      (0): Sequential(
        (0): ConvTranspose2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (1): Sequential(
        (0): ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (2): Sequential(
        (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)
        (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        (2): ReLU()
      )
    )
  )
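Counting the BatchNorm2d layers in this printout helps with the later steps: there are 19 of them, and their widths are exactly the default cfg_index used in the rewritten backbone in step (4). In the stock OpenPCDet implementation, PillarVFE normalizes with BatchNorm1d and AnchorHeadSingle has no BN, so filtering on nn.BatchNorm2d selects only these backbone layers. A small pure-Python tally, with the list transcribed by hand from the printout above:

```python
# Width of every BatchNorm2d in the BaseBEVBackbone printout above,
# in module-traversal order: blocks[0], blocks[1], blocks[2], then deblocks.
bn_channels = (
    [64] * 4      # blocks[0]: stride-2 conv + 3 convs, each followed by BN(64)
    + [128] * 6   # blocks[1]: stride-2 conv + 5 convs, each followed by BN(128)
    + [256] * 6   # blocks[2]: stride-2 conv + 5 convs, each followed by BN(256)
    + [128] * 3   # deblocks[0..2]: one BN(128) after each ConvTranspose2d
)

num_bn_layers = len(bn_channels)   # one entry per BN layer in cfg_index / cfg_mask
total_channels = sum(bn_channels)  # number of BN scale factors ranked in step 2.2 (1)
print(num_bn_layers, total_channels)  # 19 2944
```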

2. Pruning

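2.1 Sparse training
    Add an L1 penalty on the BN scale factors (gamma, the BN weight) during training so that most of them approach zero; this reduces the accuracy loss caused by pruning.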

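import torch

def updateBN(model, s=0.0001):
    """L1 sparsity on BN scale factors: add s * sign(gamma) to each BatchNorm2d weight gradient."""
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d) and m.weight.grad is not None:
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))  # subgradient of s * |gamma|

# In the training loop: call updateBN after backward() and before optimizer.step()
loss.backward()
updateBN(model)
optimizer.step()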
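The extra term is the subgradient of the penalty s·Σ|gamma| over all BN scale factors, so with plain SGD (ignoring momentum and weight decay) each step additionally moves every non-zero gamma toward zero by lr·s. A framework-free NumPy sketch of that effect; l1_bn_grad is a hypothetical helper, and the toy gamma values and the exaggerated s=0.1, lr=1.0 are made up for illustration:

```python
import numpy as np

def l1_bn_grad(task_grad, gamma, s=1e-4):
    """BN scale-factor gradient with the L1 subgradient s * sign(gamma) added, as updateBN does."""
    return task_grad + s * np.sign(gamma)

gamma = np.array([0.5, -0.2, 0.0, 1.0])   # toy BN scale factors
task_grad = np.zeros_like(gamma)          # zero task gradient isolates the penalty's effect
grad = l1_bn_grad(task_grad, gamma, s=0.1)

lr = 1.0
gamma_next = gamma - lr * grad            # one plain-SGD step
print(grad)        # penalty only: +0.1 for positive, -0.1 for negative, 0 for zero gamma
print(gamma_next)  # every non-zero gamma moved 0.1 closer to zero
```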

2.2 Pruning the sparsely trained model (Network_Slimming)

(1) Compute the global threshold from the pruning ratio (percent)

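    import torch
    import torch.nn as nn

    percent = 0.7  # pruning ratio: fraction of all BN channels to remove

    # Count all BatchNorm2d channels in the model
    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]

    # Gather |gamma| of every BN channel into one vector
    bn = torch.zeros(total)
    index = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size

    # Global threshold: the |gamma| value at position `percent` of the sorted vector
    y, i = torch.sort(bn)
    thre_index = int(total * percent)
    thre = y[thre_index]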
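Because all BN layers are ranked together, the threshold is global: a layer whose gamma values are all small can lose most of its channels while another keeps nearly all of them. The same computation in NumPy on made-up toy values; global_bn_threshold is a hypothetical name, and the min(...) clamp is an addition that keeps percent=1.0 from indexing past the end:

```python
import numpy as np

def global_bn_threshold(bn_weights, percent):
    """|gamma| value at the `percent` position of all BN scale factors sorted together."""
    all_gamma = np.sort(np.concatenate([np.abs(w) for w in bn_weights]))
    thre_index = min(int(len(all_gamma) * percent), len(all_gamma) - 1)
    return all_gamma[thre_index]

# Three toy "BN layers" with 3, 2 and 5 channels (10 scale factors in total)
bn_weights = [
    np.array([0.9, 0.01, 0.5]),
    np.array([0.02, -0.8]),
    np.array([0.03, 0.7, 0.6, 0.04, 0.4]),
]
thre = global_bn_threshold(bn_weights, percent=0.4)
print(thre)  # 0.4: sorted |gamma| starts [0.01, 0.02, 0.03, 0.04, 0.4, ...], index int(10 * 0.4) = 4
```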

(2) Generate cfg_index (the remaining channel count of each BN layer) and cfg_mask (the per-channel keep mask of each BN layer)

    pruned = 0
    cfg_index = []  # remaining channel count of each BN layer
    cfg_mask = []   # 0/1 channel mask of each BN layer
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.cpu().gt(thre).float().cuda()  # keep channels with |gamma| > thre
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            # Zero gamma and beta of the pruned channels in the original model
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg_index.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.MaxPool2d):
            # Never triggered here: BaseBEVBackbone contains no MaxPool2d
            cfg_index.append('M')
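Each layer's mask keeps the channels whose |gamma| is strictly above the global threshold, and its cfg_index entry is the number of kept channels. With a global threshold a layer can end up with zero channels, which would make the rebuilt conv invalid; the min_keep fallback below (keep the largest-|gamma| channels) is a hypothetical safeguard that the code above does not have. A NumPy sketch on made-up toy layers (build_masks is a hypothetical name):

```python
import numpy as np

def build_masks(bn_weights, thre, min_keep=1):
    """Per-layer 0/1 masks (|gamma| > thre) and kept-channel counts (cfg_index)."""
    cfg_index, cfg_mask = [], []
    for w in bn_weights:
        mask = (np.abs(w) > thre).astype(np.float32)
        if mask.sum() < min_keep:
            # Hypothetical safeguard: keep the min_keep channels with the largest |gamma|
            mask[np.argsort(-np.abs(w))[:min_keep]] = 1.0
        cfg_index.append(int(mask.sum()))
        cfg_mask.append(mask)
    return cfg_index, cfg_mask

bn_weights = [
    np.array([0.9, 0.01, 0.5]),
    np.array([0.02, -0.8]),
    np.array([0.03, 0.7, 0.6, 0.04, 0.4]),
]
cfg_index, cfg_mask = build_masks(bn_weights, thre=0.4)
print(cfg_index)    # [2, 1, 2]
print(cfg_mask[2])  # [0. 1. 1. 0. 0.]  (0.4 itself is not > thre, so it is pruned)
```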

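(3) For BN layers that should not be pruned, set every entry of their cfg_mask to 1

For example:

cfg_mask[0][:] = 1   # do not prune the first BN layer

Notes: 1) There is probably a better way, and it depends on the specific problem; this just gets it working.
       2) Many layers cannot be pruned (for example, those whose width is hard-coded in the rewritten BaseBEVBackbone in step (4)), so be careful.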
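For the BaseBEVBackbone rewrite in step (4), the BN layers with hard-coded widths are indices 0 and 3 (first and last BN of blocks[0]), 9 and 15 (last BN of blocks[1] and blocks[2]), plus the three deblock BNs 16, 17 and 18; their masks must stay all ones. The NumPy sketch below forces those masks to 1 and recomputes cfg_index so the two lists stay consistent (FIXED_BN_LAYERS and keep_layers are hypothetical names; the toy masks are made up). One more pitfall: step (2) has already multiplied every BN's weight and bias by its original mask, so forcing a mask to 1 afterwards does not restore the channels zeroed in model; forcing these masks inside the step (2) loop, before the mul_ calls, would avoid that.

```python
import numpy as np

# BN layers whose channel count is hard-coded in the rewritten BaseBEVBackbone (hypothetical constant)
FIXED_BN_LAYERS = (0, 3, 9, 15, 16, 17, 18)

def keep_layers(cfg_mask, layer_ids):
    """Set the masks of the given BN layers to all ones and return the matching cfg_index."""
    for i in layer_ids:
        cfg_mask[i][:] = 1
    return [int(m.sum()) for m in cfg_mask]

# Toy masks for the 19 BN layers: pretend every layer lost its first channel
widths = [64] * 4 + [128] * 6 + [256] * 6 + [128] * 3
cfg_mask = [np.concatenate([[0.0], np.ones(w - 1)]) for w in widths]
cfg_index = keep_layers(cfg_mask, FIXED_BN_LAYERS)
print(cfg_index)
```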

(4) Build the pruned model skeleton from cfg_index

newmodel = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set, cfg_index=cfg_index)
newmodel = newmodel.to(device='cuda:0')

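Note: build_network has to be modified for this. Since I mainly prune BaseBEVBackbone, cfg_index is passed down to this module, and the input/output width of each of its conv layers follows cfg_index, as shown below: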

import numpy as np
import torch
import torch.nn as nn


class BaseBEVBackbone(nn.Module):
    def __init__(self, model_cfg, input_channels, cfg_index=None):
        super().__init__()
        self.model_cfg = model_cfg

        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums = layer_strides = num_filters = []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides = num_upsample_filters = [] 

        num_levels = len(layer_nums)
        c_in_list = [input_channels, *num_filters[:-1]]
        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()
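        # cfg_index: remaining channel count of each of the 19 BatchNorm2d layers
        # (blocks[0]: 0-3, blocks[1]: 4-9, blocks[2]: 10-15, deblocks: 16-18).
        # Entries 0, 3, 9, 15 and 16-18 are not read below: those widths stay fixed.
        if cfg_index is None:
            cfg_index = [64, 64, 64, 64, 128, 128, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 128, 128, 128]
        cfg = cfg_index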
        for idx in range(num_levels): 
            if idx == 0:
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(
                         64,64, kernel_size=3,
                        stride=layer_strides[idx], padding=0, bias=False
                    ),
                    nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(3):
                    if k ==0:
                        cur_layers.extend([
                            nn.Conv2d(64, cfg[k+1], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k+1], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    if k ==1:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k+0], cfg[k+1], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k+1], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    if k ==2:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k+0], 64, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    
            elif idx ==1 :
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(
                        64, cfg[4], kernel_size=3,
                        stride=layer_strides[idx], padding=0, bias=False
                    ),
                    nn.BatchNorm2d(cfg[4], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(5):
                    if k ==4:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k+4], 128, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(128, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    else:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k+4], cfg[k+5], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k+5], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            elif idx ==2 :
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(
                        128, cfg[10], kernel_size=3,
                        stride=layer_strides[idx], padding=0, bias=False
                    ),
                    nn.BatchNorm2d(cfg[10], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(5):
                    if k==4:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k+10], 256, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(256, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    else:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k+10], cfg[k+11], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k+11], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            self.blocks.append(nn.Sequential(*cur_layers))
            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    stride = int(np.round(1 / stride))  # np.int was removed in NumPy 1.24
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters)
        if len(upsample_strides) > num_levels:
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in


    def forward(self, data_dict):
        """
        Args:
            data_dict:
                spatial_features
        Returns:
        """
        spatial_features = data_dict['spatial_features'] 
        ups = []
        ret_dict = {}
        x = spatial_features
           
        for i in range(len(self.blocks)):         
            x = self.blocks[i](x)            
            stride = int(spatial_features.shape[2] / x.shape[2])
            ret_dict['spatial_features_%dx' % stride] = x
            if len(self.deblocks) > 0:
                ups.append(self.deblocks[i](x))
            else:
                ups.append(x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x

        return data_dict
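To check that a given cfg_index wires up consistently, the conv widths of the rewritten blocks can be listed without building the network. This pure-Python helper (a hypothetical name) mirrors the indexing in the __init__ above; with the default cfg_index it reproduces the widths in the printout from section 1, and it makes explicit that entries 0, 3, 9 and 15 are never read (the deblocks do not read cfg at all). The pruned counts in the example are made up.

```python
def backbone_conv_shapes(cfg):
    """(in_channels, out_channels) of each Conv2d in blocks[0..2], mirroring the rewrite above."""
    shapes = [(64, 64)]                                        # blocks[0] stride-2 conv (cfg[0] unused)
    shapes += [(64, cfg[1]), (cfg[1], cfg[2]), (cfg[2], 64)]   # blocks[0], k = 0..2 (cfg[3] unused)
    shapes += [(64, cfg[4])]                                   # blocks[1] stride-2 conv
    shapes += [(cfg[k + 4], cfg[k + 5]) for k in range(4)]     # blocks[1], k = 0..3
    shapes += [(cfg[8], 128)]                                  # blocks[1], k = 4 (cfg[9] unused)
    shapes += [(128, cfg[10])]                                 # blocks[2] stride-2 conv
    shapes += [(cfg[k + 10], cfg[k + 11]) for k in range(4)]   # blocks[2], k = 0..3
    shapes += [(cfg[14], 256)]                                 # blocks[2], k = 4 (cfg[15] unused)
    return shapes

default_cfg = [64] * 4 + [128] * 6 + [256] * 6 + [128] * 3
pruned_cfg = list(default_cfg)
pruned_cfg[1], pruned_cfg[5], pruned_cfg[12] = 40, 100, 200   # made-up kept-channel counts

for in_c, out_c in backbone_conv_shapes(pruned_cfg):
    print(in_c, out_c)
```

Since consecutive convs share one cfg entry (the output of one is the input of the next), the chain stays consistent for any cfg_index by construction.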

(5) Prune the conv and BN layer parameters

    import os
    import numpy as np
    import torch
    import torch.nn as nn

    old_modules = list(model.modules())
    new_modules = list(newmodel.modules())  # same module order as model; only channel widths differ

    layer_id_in_cfg = 0
    start_mask = torch.ones(64)           # channels kept at the input of the next conv
    end_mask = cfg_mask[layer_id_in_cfg]  # channels kept at the output of the next conv / BN
    conv_count = 0
    bn_count = 0
    for layer_id in range(len(old_modules)):
        m0 = old_modules[layer_id]
        m1 = new_modules[layer_id]
        if isinstance(m0, nn.BatchNorm2d):
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))

            if bn_count == 0:
                # The first BN keeps its full width in the new model, so copy it unchanged
                # (this requires cfg_mask[0] to be all ones, see step (3)).
                m1.weight.data = m0.weight.data.clone()
                m1.bias.data = m0.bias.data.clone()
                m1.running_mean = m0.running_mean.clone()
                m1.running_var = m0.running_var.clone()
            else:
                # Keep only the channels selected by this layer's mask.
                m1.weight.data = m0.weight.data[idx1.tolist()].clone()
                m1.bias.data = m0.bias.data[idx1.tolist()].clone()
                m1.running_mean = m0.running_mean[idx1.tolist()].clone()
                m1.running_var = m0.running_var[idx1.tolist()].clone()
            bn_count += 1
            # Move to the next BN mask: this layer's output channels feed the next conv.
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        elif isinstance(m0, nn.Conv2d):
            # nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the deblocks never reach this branch.
            if conv_count == 0:
                # First Conv2d (blocks[0], 64 -> 64) keeps its width: copy it unchanged.
                m1.weight.data = m0.weight.data.clone()
                conv_count += 1
                continue
            if layer_id == (len(old_modules)-1):
                m1.weight.data = m0.weight.data.clone()
                continue
            if isinstance(old_modules[layer_id+1], nn.BatchNorm2d):
                # Conv followed by BN: slice input channels by start_mask, output filters by end_mask.
                conv_count += 1
                idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
                idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
                print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
                if idx0.size == 1:
                    idx0 = np.resize(idx0, (1,))
                if idx1.size == 1:
                    idx1 = np.resize(idx1, (1,))
                w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
                w1 = w1[idx1.tolist(), :, :, :].clone()
                m1.weight.data = w1.clone()
    # Also save cfg_index: it is needed to rebuild the pruned network before loading these weights.
    torch.save({'cfg': cfg, 'cfg_index': cfg_index, 'model_state': newmodel.state_dict()},
               os.path.join('./', 'pruned_90.pth'))
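The core of step (5) is index selection: a conv weight of shape (out, in, kh, kw) keeps the filters chosen by end_mask and the input channels chosen by start_mask, and a BN layer keeps the entries chosen by its own mask. A NumPy sketch of just that slicing on toy tensors; the helper names are hypothetical. np.flatnonzero always returns a 1-D index array, which would make the np.squeeze / np.resize handling of single-channel masks above unnecessary.

```python
import numpy as np

def prune_conv_weight(weight, in_mask, out_mask):
    """Keep the kept output filters (dim 0) and kept input channels (dim 1) of a conv weight."""
    idx_in = np.flatnonzero(in_mask)    # 1-D indices even if only one channel survives
    idx_out = np.flatnonzero(out_mask)
    return weight[idx_out][:, idx_in]

def prune_bn(bn_params, mask):
    """Slice every per-channel BN array (weight, bias, running_mean, running_var)."""
    idx = np.flatnonzero(mask)
    return {name: value[idx] for name, value in bn_params.items()}

weight = np.arange(4 * 3 * 3 * 3, dtype=np.float32).reshape(4, 3, 3, 3)  # toy conv: 3 -> 4 channels
in_mask = np.array([1.0, 0.0, 1.0])        # previous BN kept channels 0 and 2
out_mask = np.array([0.0, 1.0, 1.0, 0.0])  # this conv's BN keeps filters 1 and 2
w_new = prune_conv_weight(weight, in_mask, out_mask)
print(w_new.shape)  # (2, 2, 3, 3)

bn = {"weight": np.array([0.5, 0.9, 0.7, 0.1]), "running_var": np.ones(4)}
print(prune_bn(bn, out_mask)["weight"])  # [0.9 0.7]
```

In PyTorch the same indexing works with torch.nonzero(mask).flatten() as the index tensor.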

(6) Use the new model, load the parameters, and evaluate it
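Step (5) copies only the Conv2d and BatchNorm2d weights of the backbone blocks; layers outside them (PillarVFE, the deblocks, the detection head) are largely skipped by that loop. One way to fill them in, sketched here as an assumption rather than the author's procedure, is to copy every entry of the original state_dict whose key and shape match the new model, and keep whatever step (5) wrote for the pruned, shape-changed entries. The helper merge_matching_params is hypothetical and only uses .shape, so it is demonstrated on NumPy arrays with made-up keys:

```python
import numpy as np

def merge_matching_params(old_state, new_state):
    """Return new_state with every entry whose key and shape match in old_state taken from old_state."""
    merged, kept_new = dict(new_state), []
    for key, new_value in new_state.items():
        old_value = old_state.get(key)
        if old_value is not None and tuple(old_value.shape) == tuple(new_value.shape):
            merged[key] = old_value      # unpruned layer: reuse the trained weights
        else:
            kept_new.append(key)         # pruned layer: keep what step (5) wrote
    return merged, kept_new

# Toy state dicts: the deblock weight is unchanged in shape, the backbone conv was pruned
old_state = {"deblock.weight": np.ones((64, 128, 1, 1)), "conv.weight": np.ones((64, 64, 3, 3))}
new_state = {"deblock.weight": np.zeros((64, 128, 1, 1)), "conv.weight": np.full((40, 64, 3, 3), 2.0)}
merged, kept_new = merge_matching_params(old_state, new_state)
print(kept_new)  # ['conv.weight']
```

With PyTorch, the call would be merge_matching_params(model.state_dict(), newmodel.state_dict()) followed by newmodel.load_state_dict(merged).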

3. Testing steps

First, preprocess the data:

python -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_dataset.yaml

Then:
(1) Environment: server array g03, container zxw_compression, /data/OpenPCDet-master/tools
(2) Commands:
    Test:

CUDA_VISIBLE_DEVICES=2 python test.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 1 --ckpt checkpoint_epoch_90.pth

    Train:

python train.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 8 --epochs 100

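(3) Each time you switch to a different OpenPCDet-master copy, run the following so that the installed pcdet package points at that copy: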

python setup.py develop

4. Pruning results


5. Summary

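(1) Not every model can be accelerated by pruning: many layers have no BN, or there are too few prunable layers for a noticeable speedup.
(2) The results show that in PCDet's PointPillars the network itself takes very little time; most of the time goes to the NMS module in post-processing. Why this module is so slow has not been investigated in depth yet.
(3) Pruning the Backbone2d layers slows down post-processing; the reason still needs investigation.
(4) For architectures built from many Conv2d+BN pairs with many layers, e.g. ResNet-152, pruning can be used for acceleration. Note: "pruning" in this article means channel pruning based on BN parameters only; no other kind of pruning is covered.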
