CFPNet:用于实时语义分割的通道特征金字塔

论文地址:CFPNet: Channel-wise Feature Pyramid for Real-Time Semantic Segmentation

代码地址: https://github.com/chukai123/CFPNet

目录

1、摘要

2、本文的主要贡献

3、background

3.1、Inception module

4、the proposed method

4.1、CFP模块

 5、实验部分

5.1、数据集

5.2、实验结果比较

 6、模型代码


1、摘要

本文为了实现更好的性能,模型尺寸和推断速度,提出了channel-wise feature pyramid(CFP)模块。并且基于CFP模块,构建了CFPNet用于实时语义分割,其中采用了一系列dliate卷积通道来提取有效特征。

Cityscapes数据集中,CFPNet取得了70.1%的class-wise mIoU,并且只有0.55亿参数和2.5 MB内存。推断速度可以在单个rtx 2080ti gpu上达到30 fps,图像为1024×2048像素。

class-wise mIoU:在cityscapes中,class表示19个小类别,而mIoU表示先计算每个类别的IoU,然后对所有类别的IoU取平均即可;

category-wise (mIoU):一般是指大类别;

mIoU:针对语义分割的一个评估指标,平均交并比Mean Intersection over Union:

从上图可以看出,IoU就是分子中重叠的蓝色部分/分母中的蓝色部分减去重叠的蓝色块的比值,然后对每个类别的IoU取平均即可得到mIoU。公式如下:

mIoU=\frac{1}{k}\sum_{i=1}^{K}\frac{P\bigcap G }{P\bigcup G}

其中K表示类别个数,P表示预测集,G表示真实集;

mIoU代码实现:

(1)可以看作是分类任务,借助混淆矩阵计算mIoU

from sklearn.metrics import confusion_matrix
import numpy as np

def miou(y_true, y_pred):
    # y_true表示真实值,y_pred表示预测
    com = confusion_matrix(y_true, y_pred)
    TP = np.diag(cm) # 混淆矩阵中的对角线部分,表示预测对的数量
    FP = com.sum(axis=0) - TP 
    FN = com.sum(axis=1) - TP
    return np.mean(TP / (FN + FP + TP + np.finfo(float).eps))

(2)numpy计算版本:可以参考https://github.com/dilligencer-zrj/code_zoo/blob/master/compute_mIOU

#设标签宽W,长H
def fast_hist(a, b, n):#a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测特征图,形状(H×W,);n是类别数目
    k = (a > 0) & (a <= n) #k是一个一维bool数组,形状(H×W,);目的是找出标签中需要计算的类别(去掉了背景),假设0是背景
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def per_class_iu(hist):#分别为每个类别(在这里是19类)计算mIoU,hist的形状(n, n)
    '''
	核心代码
	'''
    return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))#矩阵的对角线上的值组成的一维数组/矩阵的所有元素之和,返回值形状(n,)

def compute_mIoU(pred,label,n_classes = args.num_class):
    hist = np.zeros((num_classes, n_classes))#hist初始化为全零,在这里的hist的形状是[n_classes, n_classes]
    hist += fast_hist(label.flatten(), pred.flatten(), n_classes) #对一张图片计算 n_classes×n_classes 的hist矩阵,并累加
    
    mIoUs = per_class_iu(hist)#计算逐类别mIoU值
    for ind_class in range(n_classes):#逐类别输出一下mIoU值
        print(str(round(mIoUs[ind_class] * 100, 2)))
    print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))#在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值
    return mIoUs

FPS:frame per second,检测器每秒能处理的图像张数。就是跟踪算法每秒钟给出多少张图片的跟踪结果。实时性一般fps>=30就表示具有实时性了。fps越高表示效率越高。省机器,省钱。

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import time

import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import init_dist, load_checkpoint, wrap_fp16_model

from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import update_data_root


def parse_args():
    parser = argparse.ArgumentParser(description='MMDet benchmark a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--repeat-num',
        type=int,
        default=1,
        help='number of repeat times of measurement for averaging the results')
    parser.add_argument(
        '--max-iter', type=int, default=2000, help='num of max iter')
    parser.add_argument(
        '--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase'
        'the inference speed')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def measure_inference_speed(cfg, checkpoint, max_iter, log_interval,
                            is_fuse_conv_bn):
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
    if samples_per_gpu > 1:
        # Replace 'ImageToTensor' to 'DefaultFormatBundle'
        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        # Because multiple processes will occupy additional CPU resources,
        # FPS statistics will be more unstable when workers_per_gpu is not 0.
        # It is reasonable to set workers_per_gpu to 0.
        workers_per_gpu=0,
        dist=True,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint, map_location='cpu')
    if is_fuse_conv_bn:
        model = fuse_conv_bn(model)

    model = MMDistributedDataParallel(
        model.cuda(),
        device_ids=[torch.cuda.current_device()],
        broadcast_buffers=False)
    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0
    fps = 0

    # benchmark with 2000 image and take the average
    for i, data in enumerate(data_loader):

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(
                    f'Done image [{i + 1:<3}/ {max_iter}], '
                    f'fps: {fps:.1f} img / s, '
                    f'times per image: {1000 / fps:.1f} ms / img',
                    flush=True)

        if (i + 1) == max_iter:
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(
                f'Overall fps: {fps:.1f} img / s, '
                f'times per image: {1000 / fps:.1f} ms / img',
                flush=True)
            break
    return fps


def repeat_measure_inference_speed(cfg,
                                   checkpoint,
                                   max_iter,
                                   log_interval,
                                   is_fuse_conv_bn,
                                   repeat_num=1):
    assert repeat_num >= 1

    fps_list = []

    for _ in range(repeat_num):
        #
        cp_cfg = copy.deepcopy(cfg)

        fps_list.append(
            measure_inference_speed(cp_cfg, checkpoint, max_iter, log_interval,
                                    is_fuse_conv_bn))

    if repeat_num > 1:
        fps_list_ = [round(fps, 1) for fps in fps_list]
        times_pre_image_list_ = [round(1000 / fps, 1) for fps in fps_list]
        mean_fps_ = sum(fps_list_) / len(fps_list_)
        mean_times_pre_image_ = sum(times_pre_image_list_) / len(
            times_pre_image_list_)
        print(
            f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, '
            f'times per image: '
            f'{times_pre_image_list_}[{mean_times_pre_image_:.1f}] ms / img',
            flush=True)
        return fps_list

    return fps_list[0]


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    # update data root according to MMDET_DATASETS
    update_data_root(cfg)

    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if args.launcher == 'none':
        raise NotImplementedError('Only supports distributed mode')
    else:
        init_dist(args.launcher, **cfg.dist_params)

    repeat_measure_inference_speed(cfg, args.checkpoint, args.max_iter,
                                   args.log_interval, args.fuse_conv_bn,
                                   args.repeat_num)


if __name__ == '__main__':
    main()

2、本文的主要贡献

  1. 提出了一个结合了 Inception 模块和空洞卷积(dliate)的模块,被称为 Channel-wise Feature Pyramid (CFP) 模块。 该模块联合提取各种尺寸的特征图和上下文信息,显著减少参数数量和模型尺寸。
  2. 基于 CFP 模块设计了 Channel-wise Feature Pyramid Network (CFPNet)。 它比现有的最先进的实时语义分割网络具有更少的参数和更好的性能。
  3. 在没有任何上下文模块、预训练模型或后处理的情况下,在 Cityscapes 和 CamVid 基准测试中都取得了极具竞争力的结果。 使用更少的参数,所提出的 CFPNet 大大优于现有的分割网络。 它可以在单个 RTX 2080Ti GPU 上以 30 FPS 处理高分辨率图像 (1024×2048),在 Cityscapes 测试数据集上产生 70.1% 的class-wise和 87.4% 的category-wise平均交并比 (mIoU) 55 万个参数。

多尺度卷积:

多尺度卷积层就是用不同大小的卷积核对某一时刻所得到的特征图进行卷积操作,得到新的大小不同的特征图,之后针对不同大小的特征图上采样到输入特征图的大小。也就是说,多尺度卷积层不会改变原有特征图的大小,只是通过不同卷积核的卷积操作,丰富了图像的特征,从全局的视角对图像中的感兴趣的特征信息进行编码解码,进而提高图像的分割性能。
参考:https://zhuanlan.zhihu.com/p/451122397
 

dliate convolution:空洞卷积

具有单一扩张率的空洞卷积可以提取全局信息,但可能会丢失局部特征。许多模型都采用空洞卷积来构建空间特征金字塔来提取多尺度特征。本文在CFP模块的每个通道中都采用了空洞卷积。

具有dilation rate为r的n×n的空洞卷积核的有效大小为:[𝑟(𝑛 − 1)+ 1]^2 

Dilated Convolution with a 3 x 3 kernel and dilation rate 2

  • a是普通的卷积过程(dilation rate = 1),卷积后的感受野为3
  • b是dilation rate = 2的空洞卷积,卷积后的感受野为5
  • c是dilation rate = 3的空洞卷积,卷积后的感受野为8

参考知乎:https://www.zhihu.com/search?type=content&q=%E7%A9%BA%E6%B4%9E%E5%8D%B7%E7%A7%AF

3、background

3.1、Inception module

最原始的Inception module

 最开始的Inception提出了并行结构,包含1×1,3×3和5×5卷积核以获得多尺度特征图。但是大卷积核会带来巨大的计算成本。因此后续版本引入了因子分解以减少参数量。

Inception-V2

 因子分解由两部分组成:小卷积和非对称卷积。如5×5卷积运算符由2个3×3卷积替换。如果将标准卷积因子化为3×1卷积,则1×3卷积可以用相同数量的过滤器保存33%的参数。

CFP模块受到因式分解方法的启发。CFP模块化采用小卷积方法的因子分解,以简化单个CNN通道的inception-like模块。然后,用非对称卷积技术减少了该通道的参数。分解大大减少了计算量,同时允许模块从一系列大小的感受野中学习特征。

Feature Pyramid channel

4、the proposed method

 结合表1看图4,对于一幅输入图像,首先经过3个卷积作为最开始的特征提取器,即表1中的No 1-3,stride分别设为2,1,1,特征维度为32。

然后采用一个down-sampling方法,对应No 4,该方法由一个步长为2的3×3卷积和一个2×2的最大池化组成,维度为64,并且整个过程做了3次下采样,即1/2到1/8;

在第一个、第二个最大池化层和最后的 1×1 卷积之前,使用skip connection来注入调整大小的输入图像,为分割网络提供额外的信息。

对于N0 5-6采用2个CFP模块,这个部分表示图4中的CFP-1,n,其中n表示重复次数n=2,维度为64,同时在CFP-1中采用dilation rate=2,对于r_{k}=2;

然后再采用down-sampling方法,对应No 7,紧接着是CFP-2,m,其中m表示重复次数m=6,而dilation rate分别设置为4,8,16;

最后使用 1 × 1 卷积来激活最终的特征图和一个简单的解码器——双线性插值来生成最终的分割掩码。 所有卷积之后都是 PReLU 激活函数和批量归一化。 因为在浅层网络中,PReLU 比 ReLU 取得了更好的性能。

面试常问的1×1卷积的作用?

1、维度升降主要于conv层的channel设置有关,和kernel_size大小无关;实际中设置的卷积核为1x1,目的更可能是为了降低参数个数,减少训练成本;

2、加入非线性。在NIN中其加入了1x1的conv层,由传统的conv升级为mlpconv的转变,使之由单纯的线性变换,变为复杂的feature map之间的线性组合,从而实现特征的高度抽象过程。这一过程视为由线性变换为非线性,提高抽象程度。而非加入激活函数的作用。

参考:https://www.zhihu.com/search?type=content&q=1%C3%971%E5%8D%B7%E7%A7%AF%E7%9A%84%E4%BD%9C%E7%94%A8

为什么采用双线性插值方法?它的作用是啥?


插值可以理解为用已知的像素点来获得未知的点。(比较通俗)

单线性插值

而双线性插值则是利用原图中的4个像素点得到新图中的1个像素点。

具体过程:分别在x和y轴方向计算3次单线性插值,如下图所示,在x轴方向,首先通过两次线性插值得到R2和R1,然后根据得到的R1和R2在y轴方向计算一次线性插值得到P;

参考: 图解双线性插值 - 搜索结果 - 知乎

回到文章中来,图4从左到右依次看,主要是CFP模块,接下来看这一部分

4.1、CFP模块

 图(a)原始的CFP模块,(b)最终版本的CFP模块

 在图(a)中,CFP模块由K个具有不同dilation rate的FP通道组成,对input(从高维映射到低维)采用1×1卷积从M维降到M/K维度;然后是第一个到第三个的维度不对称块是𝑀/4𝐾,𝑀/4𝐾和𝑀/2𝐾。多个 FP 通道设置为具有不同扩张率(dilation rate)的并行结构。 然后将所有特征映射连接到输入的维度,并使用另一个 1×1 卷积来激活输出。

在图(b)中,采用分层特征融合(HFF)解决网格伪影。 从第二个通道开始,我们采用求和运算逐步组合特征图,然后将它们连接起来以构建最终的分层特征图。 最后,减少了网格伪影的影响,如下图所示。 CFP 模块的最终版本如图 (b) 所示。

伪影是指原本被扫描物体并不存在而在图像上却出现的各种形态的影像。

空洞卷积会导致最终红蓝绿黄四个像素块,是由独立的像素(前两幅图对应颜色的像素)形成的,这种现象被称作网格伪影。

 参考:gridding artifacts(网格伪影)_百分之八的博客-CSDN博客_网格伪影

 5、实验部分

5.1、数据集

Cityscapes:主要有两种标签数据:精标签和粗标签。官网地址:https://www.cityscapes-dataset.com/

它包含5000个精细注释和20000个粗注释图像。该数据集是从不同季节和天气的 50 个不同城市捕获的。对于精细注释集,它包含2975个train,500个valid和1525个test图像。原始图像的分辨率为1024×2048。整个数据包含19个类别,分为7类(例如,车辆,卡车和公共汽车属于车辆类别)。

数据集下载链接:https://pan.baidu.com/s/1jH9GUDX4grcEoDNLsWPKGw
提取码:aChQ

Fine annotations
Coarse annotations

CamVid:

CamVid也是城市场景数据,用于自动驾驶。它包括701张图像,分辨率为720×960,train包含367,valid包含101,test包含233。它包含11种,本文在训练前将这些图像调整到360×480。

 官方地址:Object Recognition in Video Dataset

5.2、实验结果比较

在Accuracy and model size方面的效果比较

Network size and class-wise mIoU accuracy

Network size vs category-wise mIoU accuracy

 在mIoU、FPS、Parameters方面的比较如下:

 6、模型代码

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["CFPNet"]

class DeConv(nn.Module):
    def __init__(self, nIn, nOut, kSize, stride, padding, output_padding, dilation=(1, 1), groups=1, bn_acti=False, bias=False):
        super().__init__()

        self.bn_acti = bn_acti

        self.conv = nn.ConvTranspose2d(nIn, nOut, kernel_size=kSize,
                              stride=stride, padding=padding, output_padding=output_padding,
                              dilation=dilation, groups=groups, bias=bias)

        if self.bn_acti:
            self.bn_prelu = BNPReLU(nOut)

    def forward(self, input):
        output = self.conv(input)

        if self.bn_acti:
            output = self.bn_prelu(output)

        return output
    
    
# 自定义的conv类
class Conv(nn.Module):
    def __init__(self, nIn, nOut, kSize, stride, padding, dilation=(1, 1), groups=1, bn_acti=False, bias=False):
        super().__init__()

        self.bn_acti = bn_acti

        self.conv = nn.Conv2d(nIn, nOut, kernel_size=kSize,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)

        if self.bn_acti:
            self.bn_prelu = BNPReLU(nOut)

    def forward(self, input):
        output = self.conv(input)

        if self.bn_acti:
            output = self.bn_prelu(output)

        return output

# BN+PReLU,这里的Prelu的效果是优于relu的
class BNPReLU(nn.Module):
    def __init__(self, nIn):
        super().__init__()
        self.bn = nn.BatchNorm2d(nIn, eps=1e-3)
        self.acti = nn.PReLU(nIn)

    def forward(self, input):
        output = self.bn(input)
        output = self.acti(output)

        return output



class CFPModule(nn.Module):
    def __init__(self, nIn, d=1, KSize=3,dkSize=3):
        super().__init__()
        
        self.bn_relu_1 = BNPReLU(nIn)
        self.bn_relu_2 = BNPReLU(nIn)
        self.conv1x1_1 = Conv(nIn, nIn // 4, KSize, 1, padding=1, bn_acti=True)
        
        self.dconv3x1_4_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                              padding=(1*d+1, 0), dilation=(d+1,1), groups = nIn //16, bn_acti=True)
        self.dconv1x3_4_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, 1*d+1), dilation=(1,d+1), groups = nIn //16, bn_acti=True)
        
        self.dconv3x1_4_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                              padding=(1*d+1, 0), dilation=(d+1,1),groups = nIn //16, bn_acti=True)
        self.dconv1x3_4_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, 1*d+1), dilation=(1,d+1),groups = nIn //16, bn_acti=True)        
        
        self.dconv3x1_4_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                              padding=(1*d+1, 0), dilation=(d+1,1),groups = nIn //16, bn_acti=True)
        self.dconv1x3_4_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                              padding=(0, 1*d+1), dilation=(1,d+1),groups = nIn //8, bn_acti=True) 
        
        self.dconv3x1_1_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                              padding=(1, 0),groups = nIn //16, bn_acti=True)
        self.dconv1x3_1_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, 1),groups = nIn //16, bn_acti=True)
        
        self.dconv3x1_1_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                              padding=(1, 0),groups = nIn //16, bn_acti=True)
        self.dconv1x3_1_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, 1),groups = nIn //16, bn_acti=True)
        
        self.dconv3x1_1_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                              padding=(1, 0),groups = nIn //16, bn_acti=True)
        self.dconv1x3_1_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                              padding=(0, 1),groups = nIn //8, bn_acti=True)
        
        
        self.dconv3x1_2_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                              padding=(int(d/4+1), 0), dilation=(int(d/4+1),1), groups = nIn //16, bn_acti=True)
        self.dconv1x3_2_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, int(d/4+1)), dilation=(1,int(d/4+1)), groups = nIn //16, bn_acti=True)
        
        self.dconv3x1_2_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                              padding=(int(d/4+1), 0), dilation=(int(d/4+1),1),groups = nIn //16, bn_acti=True)
        self.dconv1x3_2_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, int(d/4+1)), dilation=(1,int(d/4+1)),groups = nIn //16, bn_acti=True)        
        
        self.dconv3x1_2_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                              padding=(int(d/4+1), 0), dilation=(int(d/4+1),1),groups = nIn //16, bn_acti=True)
        self.dconv1x3_2_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                              padding=(0, int(d/4+1)), dilation=(1,int(d/4+1)),groups = nIn //8, bn_acti=True)         
        
        
        
        self.dconv3x1_3_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                              padding=(int(d/2+1), 0), dilation=(int(d/2+1),1), groups = nIn //16, bn_acti=True)
        self.dconv1x3_3_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, int(d/2+1)), dilation=(1,int(d/2+1)), groups = nIn //16, bn_acti=True)
        
        self.dconv3x1_3_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                              padding=(int(d/2+1), 0), dilation=(int(d/2+1),1),groups = nIn //16, bn_acti=True)
        self.dconv1x3_3_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                              padding=(0, int(d/2+1)), dilation=(1,int(d/2+1)),groups = nIn //16, bn_acti=True)        
        
        self.dconv3x1_3_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                              padding=(int(d/2+1), 0), dilation=(int(d/2+1),1),groups = nIn //16, bn_acti=True)
        self.dconv1x3_3_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                              padding=(0, int(d/2+1)), dilation=(1,int(d/2+1)),groups = nIn //8, bn_acti=True)              
        
        self.conv1x1 = Conv(nIn, nIn, 1, 1, padding=0,bn_acti=False)
        
    def forward(self, input):
        inp = self.bn_relu_1(input)
        inp = self.conv1x1_1(inp)
        
        o1_1 = self.dconv3x1_1_1(inp)
        o1_1 = self.dconv1x3_1_1(o1_1)
        o1_2 = self.dconv3x1_1_2(o1_1)
        o1_2 = self.dconv1x3_1_2(o1_2)
        o1_3 = self.dconv3x1_1_3(o1_2)
        o1_3 = self.dconv1x3_1_3(o1_3)
        
        o2_1 = self.dconv3x1_2_1(inp)
        o2_1 = self.dconv1x3_2_1(o2_1)
        o2_2 = self.dconv3x1_2_2(o2_1)
        o2_2 = self.dconv1x3_2_2(o2_2)
        o2_3 = self.dconv3x1_2_3(o2_2)
        o2_3 = self.dconv1x3_2_3(o2_3)        
     
        o3_1 = self.dconv3x1_3_1(inp)
        o3_1 = self.dconv1x3_3_1(o3_1)
        o3_2 = self.dconv3x1_3_2(o3_1)
        o3_2 = self.dconv1x3_3_2(o3_2)
        o3_3 = self.dconv3x1_3_3(o3_2)
        o3_3 = self.dconv1x3_3_3(o3_3)               
        
        
        o4_1 = self.dconv3x1_4_1(inp)
        o4_1 = self.dconv1x3_4_1(o4_1)
        o4_2 = self.dconv3x1_4_2(o4_1)
        o4_2 = self.dconv1x3_4_2(o4_2)
        o4_3 = self.dconv3x1_4_3(o4_2)
        o4_3 = self.dconv1x3_4_3(o4_3)               
        
        
        output_1 = torch.cat([o1_1,o1_2,o1_3], 1)
        output_2 = torch.cat([o2_1,o2_2,o2_3], 1)      
        output_3 = torch.cat([o3_1,o3_2,o3_3], 1)       
        output_4 = torch.cat([o4_1,o4_2,o4_3], 1)   
        
        ad1 = output_1
        ad2 = ad1 + output_2
        ad3 = ad2 + output_3
        ad4 = ad3 + output_4
        output = torch.cat([ad1,ad2,ad3,ad4],1)
        output = self.bn_relu_2(output)
        output = self.conv1x1(output)
        
        return output+input
        
# down-sampling
class DownSamplingBlock(nn.Module):
    """
    nIn:输入通道数
    nOut:输出通道数
    """
    def __init__(self, nIn, nOut):
        super().__init__()
        self.nIn = nIn
        self.nOut = nOut

        if self.nIn < self.nOut:
            nConv = nOut - nIn
        else:
            nConv = nOut

        self.conv3x3 = Conv(nIn, nConv, kSize=3, stride=2, padding=1)
        self.max_pool = nn.MaxPool2d(2, stride=2)
        self.bn_prelu = BNPReLU(nOut)

    def forward(self, input):
        output = self.conv3x3(input)

        if self.nIn < self.nOut:
            max_pool = self.max_pool(input)
            output = torch.cat([output, max_pool], 1)

        output = self.bn_prelu(output)

        return output

# 执行三次平均池化:1/2,1/4,1/8
class InputInjection(nn.Module):
    def __init__(self, ratio):
        super().__init__()
        self.pool = nn.ModuleList()
        for i in range(0, ratio):
            self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, input):
        for pool in self.pool:
            input = pool(input)

        return input


class CFPNet(nn.Module):
    # 数据集中的类别,block_1表示第一个CFP模块的数量,block_2表示第二个CFP模块的数量
    def __init__(self, classes=11, block_1=2, block_2=6):
        super().__init__() # 继承父类的init()方法
        
        # 前三个卷积块,用以特征提取
        self.init_conv = nn.Sequential(
            Conv(3, 32, 3, 2, padding=1, bn_acti=True),
            Conv(32, 32, 3, 1, padding=1, bn_acti=True),
            Conv(32, 32, 3, 1, padding=1, bn_acti=True),
        )

        # 论文中提到的down-sample方法,采用的是平均池化方法
        self.down_1 = InputInjection(1)  # down-sample the image 1 times:1/2
        self.down_2 = InputInjection(2)  # down-sample the image 2 times:1/4
        self.down_3 = InputInjection(3)  # down-sample the image 3 times:1/8

        # BN+PReLU
        self.bn_prelu_1 = BNPReLU(32 + 3)
        # block_1中的CFP模块中的dilation_rate
        dilation_block_1 =[2,2]
        
        # CFP Block 1
        self.downsample_1 = DownSamplingBlock(32 + 3, 64)
        self.CFP_Block_1 = nn.Sequential()
        for i in range(0, block_1):
            self.CFP_Block_1.add_module("CFP_Module_1_" + str(i), CFPModule(64, d=dilation_block_1[i]))
            
        self.bn_prelu_2 = BNPReLU(128 + 3)

        # CFP Block 2
        dilation_block_2 = [4,4,8,8,16,16] #camvid #cityscapes [4,4,8,8,16,16] # [4,8,16]
        self.downsample_2 = DownSamplingBlock(128 + 3, 128)
        self.CFP_Block_2 = nn.Sequential()
        for i in range(0, block_2):
            self.CFP_Block_2.add_module("CFP_Module_2_" + str(i),
                                        CFPModule(128, d=dilation_block_2[i]))
        self.bn_prelu_3 = BNPReLU(256 + 3)

        self.classifier = nn.Sequential(Conv(259, classes, 1, 1, padding=0))

    def forward(self, input):

        output0 = self.init_conv(input)

        # 论文中的3 time down-sample
        down_1 = self.down_1(input)
        down_2 = self.down_2(input)
        down_3 = self.down_3(input)

        # 第一次concat
        output0_cat = self.bn_prelu_1(torch.cat([output0, down_1], 1))

        # CFP Block 1
        output1_0 = self.downsample_1(output0_cat)
        output1 = self.CFP_Block_1(output1_0)
        # 第二次concat
        output1_cat = self.bn_prelu_2(torch.cat([output1, output1_0, down_2], 1))

        # CFP Block 2
        output2_0 = self.downsample_2(output1_cat)
        output2 = self.CFP_Block_2(output2_0)
        # 第三次concat
        output2_cat = self.bn_prelu_3(torch.cat([output2, output2_0, down_3], 1))

        out = self.classifier(output2_cat)
        # 上采样函数,这个在opencv中也有提供
        out = F.interpolate(out, input.size()[2:], mode='bilinear', align_corners=False)

        return out
    
    

使用:

from CFPNet import CFPNet

def build_model(model_name, num_classes):
    if model_name == 'DABNet':
        return CFPNet(classes=num_classes)

注:具体细节可以参考文献!

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

kaichu2

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值