fishnet:论文阅读与代码理解

fishnet论文地址:http://papers.nips.cc/paper/7356-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction.pdf
fishnet源码地址(pytorch版本):https://github.com/kevin-ssy/FishNet

一、论文概述

  我们知道,对应不同的计算机视觉任务(图像分类、目标检测、语义分割、实例分割等),所需要卷积神经网络提取的特征是不一样的。以图像分类任务与语义分割任务为例。图像分类对应对图片级别的对象进行预测,比如预测一张图片属于猫还是狗。那么它所需要的特征需要更加抽象化的高层次语义特征。而语义分割任务所对应的是像素级别的预测,即预测每一个像素点属于哪一类。这种任务不仅需要语义特征,而且在此基础上还需要注重低层次的细节特征。所以说针对图片级别、区域级别和像素级别的预测任务,卷积神经网络的注重点是不一样的。
 目前而言,用于图像分类的网络,如:ResNet、DenseNet等可以直接将其作为Backbone(主干网)用于区域级和像素级的预测任务(如语义分割中,常用ResNet101作为特征提取的Backbone)。但是为语义分割、目标检测等设计的网络通常是在图像分类任务中发挥不了作用的。
 于是,作者为了设计一种通用于这些任务的卷积神经网络,设计了一种名为fishnet(因为网络的形状像一条鱼,所以命名为fishnet。写论文还是得搞点花里胡哨的东西才能中啊)的网络,它既可以用于图像分类,又可以用于目标检测、语义分割等任务。换言之,也就是说,这个网络所提取的特征是语义特征与细节特征都十分丰富的。那么下面我们来看一下,fishnet是怎么做到语义特征与细节特征并重的。

二、整体框架

 如下图,是fishnet的网络结构图。可以看到,确实,还真挺像一条鱼。
在这里插入图片描述
 整个网络分为三个部分从左至右分别命名为:鱼尾(fish tail)、鱼身(fish body)、鱼头(fish head)。鱼尾实际上就是一个resnet的结构,它负责提取语义特征。到了鱼身之后,开始使用上采样提升特征图的分辨率,并进行了跳层连接。这两个操作都是为了让网络拥有更多的细节特征。至此,如果你是要进行语义分割、目标检测等任务的话,就可以不用管鱼头部分了。你可以将鱼身的输出直接上采样到原图大小(到这里,实际上就是一个类似于FCN结构的网络,只是内部实现的细节有所不同)。然后,如果想要进行图像分类任务的话,就用最后的鱼头,下采样得到最后的score vector。下面详细的讲一下这三个部分:

  1. 鱼尾:一个resnet结构。具体结构如下图。值得注意的是,这里的结构据采用maxpooling进行下采样而不采用步长为2的卷积。
    -在这里插入图片描述
  2. 鱼身与鱼头:详细结构如下图:在这里插入图片描述  鱼尾的输出特征图经过SE block的处理后得到鱼身的输入(对应图C3)。然后将其上采样一倍后与鱼尾中对应分辨率经过Transferring Block的特征图相连。这里的Transferring Block实际上就是一个Bottleneck block。串联后送入Up-sampling & Refinement block (UR-block) 中。UR blcok顾名思义就是用来讲特征图上采样与精细化特征的。上采样我们是知道的,它对应这幅图右上角的up(.)。论文中用最近淋插值法上采样。那么怎么进行特征精细化呢?它对用M(.)与r(.)操作。其中M(.)是bottleneck block 。它将特征图的通道变为输入通道图的1/k。这里的K是个超参数,人为通过实验设定。而r(.)则是把输入特征图中的相邻k个通道求和变为一个通道。这样也得到一个通道变为输入通道图的1/k的特征图。然后对二者求和得到特征细化的结果。读到这里,你可能就理解了,所谓的特征细化其实就是一个减少通道数的过程。后续在重复上采样、串联、UR block两次后便完成了鱼身的过程。得到了一个分辨率为原图1/4大小的富含语义信息与细节信息的特征图。
      随后的鱼头的才做与鱼身中类似。只不过上采样换为下采样、UR block换为DR block。而DR block与UR block的不同之处在于:
       1)使用2x2最大池化来下采样。
       2)不使用通道缩减函数,以使得当前阶段的梯度可以直接被传送到先前的阶段。

三、代码理解

  模型主要分为三个文件:

  • **fishnet.py:**构建fishnet模型的文件。主要分为两个类和一个函数:
1class Fish(nn.Module):封装了fishnet的主要结构
2class FishNet(nn.Module):调用Fish类进行更高一层的封装
3def fish(**kwargs):fishnet.py文件的对外接口,调用该函数会返回一个Fishnet类对象
  • fish_block.py:包含一个与原始resnet中经过稍微调整的bottleneck block 类。是fishnet.py文件Fish类中构建fishnet模型的重要组件。
  • net_factory.py: 整个这三个文件的对外接口。其中包含三个函数:
1def fishnet99(**kwargs):
2def fishnet150(**kwargs):
3def fishnet201(**kwargs):

调用不同的函数可以返回不同模型大小的fishnet模型。
1) fishnet.py:

from __future__ import division
import torch
import math
from .fish_block import *

__all__ = ['fish']

class Fish(nn.Module):
    def __init__(self, block, num_cls=1000, num_down_sample=5, num_up_sample=3, trans_map=(2, 1, 0, 6, 5, 4),
                 network_planes=None, num_res_blks=None, num_trans_blks=None):
        super(Fish, self).__init__()
        self.block = block
        self.trans_map = trans_map
        self.upsample = nn.Upsample(scale_factor=2)
        self.down_sample = nn.MaxPool2d(2, stride=2)
        self.num_cls = num_cls
        self.num_down = num_down_sample
        self.num_up = num_up_sample
        self.network_planes = network_planes[1:]
        self.depth = len(self.network_planes)
        self.num_trans_blks = num_trans_blks
        self.num_res_blks = num_res_blks
        self.fish = self._make_fish(network_planes[0])

    def _make_score(self, in_ch, out_ch=1000, has_pool=False):
        bn = nn.BatchNorm2d(in_ch)
        relu = nn.ReLU(inplace=True)
        conv_trans = nn.Conv2d(in_ch, in_ch // 2, kernel_size=1, bias=False)
        bn_out = nn.BatchNorm2d(in_ch // 2)
        conv = nn.Sequential(bn, relu, conv_trans, bn_out, relu)
        if has_pool:
            fc = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True))
        else:
            fc = nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True)
        return [conv, fc]

    def _make_se_block(self, in_ch, out_ch):
        bn = nn.BatchNorm2d(in_ch)
        sq_conv = nn.Conv2d(in_ch, out_ch // 16, kernel_size=1)
        ex_conv = nn.Conv2d(out_ch // 16, out_ch, kernel_size=1)
        return nn.Sequential(bn,
                             nn.ReLU(inplace=True),
                             nn.AdaptiveAvgPool2d(1),
                             sq_conv,
                             nn.ReLU(inplace=True),
                             ex_conv,
                             nn.Sigmoid())

    def _make_residual_block(self, inplanes, outplanes, nstage, is_up=False, k=1, dilation=1):
        layers = []

        if is_up:
            layers.append(self.block(inplanes, outplanes, mode='UP', dilation=dilation, k=k))
        else:
            layers.append(self.block(inplanes, outplanes, stride=1))
        for i in range(1, nstage):
            layers.append(self.block(outplanes, outplanes, stride=1, dilation=dilation))
        return nn.Sequential(*layers)

    def _make_stage(self, is_down_sample, inplanes, outplanes, n_blk, has_trans=True,
                    has_score=False, trans_planes=0, no_sampling=False, num_trans=2, **kwargs):
        sample_block = []
        if has_score:
            sample_block.extend(self._make_score(outplanes, outplanes * 2, has_pool=False))

        if no_sampling or is_down_sample:
            res_block = self._make_residual_block(inplanes, outplanes, n_blk, **kwargs)
        else:
            res_block = self._make_residual_block(inplanes, outplanes, n_blk, is_up=True, **kwargs)

        sample_block.append(res_block)

        if has_trans:
            trans_in_planes = self.in_planes if trans_planes == 0 else trans_planes
            sample_block.append(self._make_residual_block(trans_in_planes, trans_in_planes, num_trans))

        if not no_sampling and is_down_sample:
            sample_block.append(self.down_sample)
        elif not no_sampling:  # Up-Sample
            sample_block.append(self.upsample)

        return nn.ModuleList(sample_block)

    def _make_fish(self, in_planes):
        def get_trans_planes(index):
            map_id = self.trans_map[index-self.num_down-1] - 1
            p = in_planes if map_id == -1 else cated_planes[map_id]
            return p

        def get_trans_blk(index):
            return self.num_trans_blks[index-self.num_down-1]

        def get_cur_planes(index):
            return self.network_planes[index]

        def get_blk_num(index):
            return self.num_res_blks[index]

        cated_planes, fish = [in_planes] * self.depth, []
        for i in range(self.depth):
            # even num for down-sample, odd for up-sample
            is_down, has_trans, no_sampling = i not in range(self.num_down, self.num_down+self.num_up+1),\
                                              i > self.num_down, i == self.num_down
            # is_down, has_trans, no_sampling:True False False; True False False; True False False; False False True
            # False True False; False True False; False True False; True True False;True True False; True True False
            cur_planes, trans_planes, cur_blocks, num_trans =\
                get_cur_planes(i), get_trans_planes(i), get_blk_num(i), get_trans_blk(i)
            # cur_planes, trans_planes, cur_blocks, num_trans:128 64 2 1;256 64 2 1; 512 64 6 1; 512 64 2 4
            # 512 256 1 1; 384 128 1 1; 256 64 1 1; 320 512 1 1;832 768 2 1; 1600 512 2 4

            stg_args = [is_down, cated_planes[i - 1], cur_planes, cur_blocks]
            # inplanes:64,128,256,512,1024,512,768,512,320,832,1600

            if is_down or no_sampling:
                k, dilation = 1, 1
            else:
                k, dilation = cated_planes[i - 1] // cur_planes, 2 ** (i-self.num_down-1)

            sample_block = self._make_stage(*stg_args, has_trans=has_trans, trans_planes=trans_planes,
                                            has_score=(i==self.num_down), num_trans=num_trans, k=k, dilation=dilation,
                                            no_sampling=no_sampling)
            if i == self.depth - 1:
                sample_block.extend(self._make_score(cur_planes + trans_planes, out_ch=self.num_cls, has_pool=True))
            elif i == self.num_down:
                sample_block.append(nn.Sequential(self._make_se_block(cur_planes*2, cur_planes)))

            if i == self.num_down-1:
                cated_planes[i] = cur_planes * 2
            elif has_trans:
                cated_planes[i] = cur_planes + trans_planes
            else:
                cated_planes[i] = cur_planes
            fish.append(sample_block)
        return nn.ModuleList(fish)

    def _fish_forward(self, all_feat):
        def _concat(a, b):
            return torch.cat([a, b], dim=1)

        def stage_factory(*blks):
            def stage_forward(*inputs):
                if stg_id < self.num_down:  # tail
                    tail_blk = nn.Sequential(*blks[:2])
                    # print(stg_id)
                    # print(tail_blk)
                    return tail_blk(*inputs)
                elif stg_id == self.num_down:
                    score_blks = nn.Sequential(*blks[:2])
                    score_feat = score_blks(inputs[0])
                    att_feat = blks[3](score_feat)
                    return blks[2](score_feat) * att_feat + att_feat
                else:  # refine
                    feat_trunk = blks[2](blks[0](inputs[0]))
                    feat_branch = blks[1](inputs[1])
                return _concat(feat_trunk, feat_branch)
            return stage_forward

        stg_id = 0
        # tail:
        while stg_id < self.depth:
            stg_blk = stage_factory(*self.fish[stg_id])
            if stg_id <= self.num_down:
                in_feat = [all_feat[stg_id]]
            else:
                trans_id = self.trans_map[stg_id-self.num_down-1]
                in_feat = [all_feat[stg_id], all_feat[trans_id]]

            all_feat[stg_id + 1] = stg_blk(*in_feat)
            stg_id += 1
            # loop exit
            if stg_id == self.depth:
                score_feat = self.fish[self.depth-1][-2](all_feat[-1])
                score = self.fish[self.depth-1][-1](score_feat)
                for fea in all_feat:
                    print(fea.shape)
                return score

    def forward(self, x):
        all_feat = [None] * (self.depth + 1)
        all_feat[0] = x
        return self._fish_forward(all_feat)


class FishNet(nn.Module):
    def __init__(self, block, **kwargs):
        super(FishNet, self).__init__()

        inplanes = kwargs['network_planes'][0]
        # resolution: 224x224
        self.conv1 = self._conv_bn_relu(3, inplanes // 2, stride=2)
        self.conv2 = self._conv_bn_relu(inplanes // 2, inplanes // 2)
        self.conv3 = self._conv_bn_relu(inplanes // 2, inplanes)
        self.pool1 = nn.MaxPool2d(3, padding=1, stride=2)
        # construct fish, resolution 56x56
        self.fish = Fish(block, **kwargs)
        self._init_weights()

    def _conv_bn_relu(self, in_ch, out_ch, stride=1):
        return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride, bias=False),
                             nn.BatchNorm2d(out_ch),
                             nn.ReLU(inplace=True))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.pool1(x)
        # x.Size([1, 64, 56, 56])
        score = self.fish(x)
        # 1*1 output
        out = score.view(x.size(0), -1)

        return out


def fish(**kwargs):
    return FishNet(Bottleneck, **kwargs)

2) fish_block.py:

import torch.nn as nn


class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, mode='NORM', k=1, dilation=1):
        """
        Pre-act residual block, the middle transformations are bottle-necked
        :param inplanes:
        :param planes:
        :param stride:
        :param downsample:
        :param mode: NORM | UP
        :param k: times of additive
        """

        super(Bottleneck, self).__init__()
        self.mode = mode
        self.relu = nn.ReLU(inplace=True)
        self.k = k

        btnk_ch = planes // 4
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, btnk_ch, kernel_size=1, bias=False)

        self.bn2 = nn.BatchNorm2d(btnk_ch)
        self.conv2 = nn.Conv2d(btnk_ch, btnk_ch, kernel_size=3, stride=stride, padding=dilation,
                               dilation=dilation, bias=False)

        self.bn3 = nn.BatchNorm2d(btnk_ch)
        self.conv3 = nn.Conv2d(btnk_ch, planes, kernel_size=1, bias=False)

        if mode == 'UP':
            self.shortcut = None
        elif inplanes != planes or stride > 1:
            self.shortcut = nn.Sequential(
                nn.BatchNorm2d(inplanes),
                self.relu,
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
            )
        else:
            self.shortcut = None

    def _pre_act_forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.mode == 'UP':
            residual = self.squeeze_idt(x)
        elif self.shortcut is not None:
            residual = self.shortcut(residual)

        out += residual

        return out

    def squeeze_idt(self, idt):
        n, c, h, w = idt.size()
        return idt.view(n, c // self.k, self.k, h, w).sum(2)

    def forward(self, x):
        out = self._pre_act_forward(x)
        return out

3) fish_block.py:

from models.fishnet import fish
import torch

def fishnet99(**kwargs):
    """

    :return:
    """
    net_cfg = {
        #  input size:   [224, 56, 28,  14  |  7,   7,  14,  28 | 56,   28,  14]
        # output size:   [56,  28, 14,   7  |  7,  14,  28,  56 | 28,   14,   7]
        #                  |    |    |   |     |    |    |    |    |     |    |
        'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
        'num_res_blks': [2, 2, 6, 2, 1, 1, 1, 1, 2, 2],
        'num_trans_blks': [1, 1, 1, 1, 1, 4],
        'num_cls': 1000,
        'num_down_sample': 3,
        'num_up_sample': 3,
    }
    cfg = {**net_cfg, **kwargs}
    return fish(**cfg)


def fishnet150(**kwargs):
    """

    :return:
    """
    net_cfg = {
        #  input size:   [224, 56, 28,  14  |  7,   7,  14,  28 | 56,   28,  14]
        # output size:   [56,  28, 14,   7  |  7,  14,  28,  56 | 28,   14,   7]
        #                  |    |    |   |     |    |    |    |    |     |    |
        'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
        'num_res_blks': [2, 4, 8, 4, 2, 2, 2, 2, 2, 4],
        'num_trans_blks': [2, 2, 2, 2, 2, 4],
        'num_cls': 1000,
        'num_down_sample': 3,
        'num_up_sample': 3,
    }
    cfg = {**net_cfg, **kwargs}
    return fish(**cfg)


def fishnet201(**kwargs):
    """

    :return:
    """
    net_cfg = {
        #  input size:   [224, 56, 28,  14  |  7,   7,  14,  28 | 56,   28,  14]
        # output size:   [56,  28, 14,   7  |  7,  14,  28,  56 | 28,   14,   7]
        #                  |    |    |   |     |    |    |    |    |     |    |
        'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
        'num_res_blks': [3, 4, 12, 4, 2, 2, 2, 2, 3, 10],
        'num_trans_blks': [2, 2, 2, 2, 2, 9],
        'num_cls': 1000,
        'num_down_sample': 3,
        'num_up_sample': 3,
    }
    cfg = {**net_cfg, **kwargs}
    return fish(**cfg)

四、总结

  • 创造性的在类FCN的网络后再次添加了卷积神经网络。这样的处理使得用于目标检测、语义分割等任务的卷积神经网络可以用于图像分类。并且充分利用到了卷积神经网络所提取到的细节信息。
  • 在网络中不再使用孤立卷积,使得深层的梯度可以直接传递到浅层。
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值