fishnet论文地址:http://papers.nips.cc/paper/7356-fishnet-a-versatile-backbone-for-image-region-and-pixel-level-prediction.pdf
fishnet源码地址(pytorch版本):https://github.com/kevin-ssy/FishNet
一、论文概述
我们知道,对应不同的计算机视觉任务(图像分类、目标检测、语义分割、实例分割等),所需要卷积神经网络提取的特征是不一样的。以图像分类任务与语义分割任务为例。图像分类对应对图片级别的对象进行预测,比如预测一张图片属于猫还是狗。那么它所需要的特征需要更加抽象化的高层次语义特征。而语义分割任务所对应的是像素级别的预测,即预测每一个像素点属于哪一类。这种任务不仅需要语义特征,而且在此基础上还需要注重低层次的细节特征。所以说针对图片级别、区域级别和像素级别的预测任务,卷积神经网络的注重点是不一样的。
目前而言,用于图像分类的网络,如:ResNet、DenseNet等可以直接将其作为Backbone(主干网)用于区域级和像素级的预测任务(如语义分割中,常用ResNet101作为特征提取的Backbone)。但是为语义分割、目标检测等设计的网络通常是在图像分类任务中发挥不了作用的。
于是,作者为了设计一种通用于这些任务的卷积神经网络,设计了一种名为fishnet(因为网络的形状像一条鱼,所以命名为fishnet。写论文还是得搞点花里胡哨的东西才能中啊)的网络,它既可以用于图像分类,又可以用于目标检测、语义分割等任务。换言之,也就是说,这个网络所提取的特征是语义特征与细节特征都十分丰富的。那么下面我们来看一下,fishnet是怎么做到语义特征与细节特征并重的。
二、整体框架
如下图,是fishnet的网络结构图。可以看到,确实,还真挺像一条鱼。
整个网络分为三个部分从左至右分别命名为:鱼尾(fish tail)、鱼身(fish body)、鱼头(fish head)。鱼尾实际上就是一个resnet的结构,它负责提取语义特征。到了鱼身之后,开始使用上采样提升特征图的分辨率,并进行了跳层连接。这两个操作都是为了让网络拥有更多的细节特征。至此,如果你是要进行语义分割、目标检测等任务的话,就可以不用管鱼头部分了。你可以将鱼身的输出直接上采样到原图大小(到这里,实际上就是一个类似于FCN结构的网络,只是内部实现的细节有所不同)。然后,如果想要进行图像分类任务的话,就用最后的鱼头,下采样得到最后的score vector。下面详细的讲一下这三个部分:
- 鱼尾:一个resnet结构。具体结构如下图。值得注意的是,这里的结构据采用maxpooling进行下采样而不采用步长为2的卷积。
- - 鱼身与鱼头:详细结构如下图:
鱼尾的输出特征图经过SE block的处理后得到鱼身的输入(对应图C3)。然后将其上采样一倍后与鱼尾中对应分辨率经过Transferring Block的特征图相连。这里的Transferring Block实际上就是一个Bottleneck block。串联后送入Up-sampling & Refinement block (UR-block) 中。UR blcok顾名思义就是用来讲特征图上采样与精细化特征的。上采样我们是知道的,它对应这幅图右上角的up(.)。论文中用最近淋插值法上采样。那么怎么进行特征精细化呢?它对用M(.)与r(.)操作。其中M(.)是bottleneck block 。它将特征图的通道变为输入通道图的1/k。这里的K是个超参数,人为通过实验设定。而r(.)则是把输入特征图中的相邻k个通道求和变为一个通道。这样也得到一个通道变为输入通道图的1/k的特征图。然后对二者求和得到特征细化的结果。读到这里,你可能就理解了,所谓的特征细化其实就是一个减少通道数的过程。后续在重复上采样、串联、UR block两次后便完成了鱼身的过程。得到了一个分辨率为原图1/4大小的富含语义信息与细节信息的特征图。
随后的鱼头的才做与鱼身中类似。只不过上采样换为下采样、UR block换为DR block。而DR block与UR block的不同之处在于:
1)使用2x2最大池化来下采样。
2)不使用通道缩减函数,以使得当前阶段的梯度可以直接被传送到先前的阶段。
三、代码理解
模型主要分为三个文件:
- **fishnet.py:**构建fishnet模型的文件。主要分为两个类和一个函数:
1、class Fish(nn.Module):封装了fishnet的主要结构
2、class FishNet(nn.Module):调用Fish类进行更高一层的封装
3、def fish(**kwargs):fishnet.py文件的对外接口,调用该函数会返回一个Fishnet类对象
- fish_block.py:包含一个与原始resnet中经过稍微调整的bottleneck block 类。是fishnet.py文件Fish类中构建fishnet模型的重要组件。
- net_factory.py: 整个这三个文件的对外接口。其中包含三个函数:
1、def fishnet99(**kwargs):
2、def fishnet150(**kwargs):
3、def fishnet201(**kwargs):
调用不同的函数可以返回不同模型大小的fishnet模型。
1) fishnet.py:
from __future__ import division
import torch
import math
from .fish_block import *
__all__ = ['fish']
class Fish(nn.Module):
def __init__(self, block, num_cls=1000, num_down_sample=5, num_up_sample=3, trans_map=(2, 1, 0, 6, 5, 4),
network_planes=None, num_res_blks=None, num_trans_blks=None):
super(Fish, self).__init__()
self.block = block
self.trans_map = trans_map
self.upsample = nn.Upsample(scale_factor=2)
self.down_sample = nn.MaxPool2d(2, stride=2)
self.num_cls = num_cls
self.num_down = num_down_sample
self.num_up = num_up_sample
self.network_planes = network_planes[1:]
self.depth = len(self.network_planes)
self.num_trans_blks = num_trans_blks
self.num_res_blks = num_res_blks
self.fish = self._make_fish(network_planes[0])
def _make_score(self, in_ch, out_ch=1000, has_pool=False):
bn = nn.BatchNorm2d(in_ch)
relu = nn.ReLU(inplace=True)
conv_trans = nn.Conv2d(in_ch, in_ch // 2, kernel_size=1, bias=False)
bn_out = nn.BatchNorm2d(in_ch // 2)
conv = nn.Sequential(bn, relu, conv_trans, bn_out, relu)
if has_pool:
fc = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True))
else:
fc = nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True)
return [conv, fc]
def _make_se_block(self, in_ch, out_ch):
bn = nn.BatchNorm2d(in_ch)
sq_conv = nn.Conv2d(in_ch, out_ch // 16, kernel_size=1)
ex_conv = nn.Conv2d(out_ch // 16, out_ch, kernel_size=1)
return nn.Sequential(bn,
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(1),
sq_conv,
nn.ReLU(inplace=True),
ex_conv,
nn.Sigmoid())
def _make_residual_block(self, inplanes, outplanes, nstage, is_up=False, k=1, dilation=1):
layers = []
if is_up:
layers.append(self.block(inplanes, outplanes, mode='UP', dilation=dilation, k=k))
else:
layers.append(self.block(inplanes, outplanes, stride=1))
for i in range(1, nstage):
layers.append(self.block(outplanes, outplanes, stride=1, dilation=dilation))
return nn.Sequential(*layers)
def _make_stage(self, is_down_sample, inplanes, outplanes, n_blk, has_trans=True,
has_score=False, trans_planes=0, no_sampling=False, num_trans=2, **kwargs):
sample_block = []
if has_score:
sample_block.extend(self._make_score(outplanes, outplanes * 2, has_pool=False))
if no_sampling or is_down_sample:
res_block = self._make_residual_block(inplanes, outplanes, n_blk, **kwargs)
else:
res_block = self._make_residual_block(inplanes, outplanes, n_blk, is_up=True, **kwargs)
sample_block.append(res_block)
if has_trans:
trans_in_planes = self.in_planes if trans_planes == 0 else trans_planes
sample_block.append(self._make_residual_block(trans_in_planes, trans_in_planes, num_trans))
if not no_sampling and is_down_sample:
sample_block.append(self.down_sample)
elif not no_sampling: # Up-Sample
sample_block.append(self.upsample)
return nn.ModuleList(sample_block)
def _make_fish(self, in_planes):
def get_trans_planes(index):
map_id = self.trans_map[index-self.num_down-1] - 1
p = in_planes if map_id == -1 else cated_planes[map_id]
return p
def get_trans_blk(index):
return self.num_trans_blks[index-self.num_down-1]
def get_cur_planes(index):
return self.network_planes[index]
def get_blk_num(index):
return self.num_res_blks[index]
cated_planes, fish = [in_planes] * self.depth, []
for i in range(self.depth):
# even num for down-sample, odd for up-sample
is_down, has_trans, no_sampling = i not in range(self.num_down, self.num_down+self.num_up+1),\
i > self.num_down, i == self.num_down
# is_down, has_trans, no_sampling:True False False; True False False; True False False; False False True
# False True False; False True False; False True False; True True False;True True False; True True False
cur_planes, trans_planes, cur_blocks, num_trans =\
get_cur_planes(i), get_trans_planes(i), get_blk_num(i), get_trans_blk(i)
# cur_planes, trans_planes, cur_blocks, num_trans:128 64 2 1;256 64 2 1; 512 64 6 1; 512 64 2 4
# 512 256 1 1; 384 128 1 1; 256 64 1 1; 320 512 1 1;832 768 2 1; 1600 512 2 4
stg_args = [is_down, cated_planes[i - 1], cur_planes, cur_blocks]
# inplanes:64,128,256,512,1024,512,768,512,320,832,1600
if is_down or no_sampling:
k, dilation = 1, 1
else:
k, dilation = cated_planes[i - 1] // cur_planes, 2 ** (i-self.num_down-1)
sample_block = self._make_stage(*stg_args, has_trans=has_trans, trans_planes=trans_planes,
has_score=(i==self.num_down), num_trans=num_trans, k=k, dilation=dilation,
no_sampling=no_sampling)
if i == self.depth - 1:
sample_block.extend(self._make_score(cur_planes + trans_planes, out_ch=self.num_cls, has_pool=True))
elif i == self.num_down:
sample_block.append(nn.Sequential(self._make_se_block(cur_planes*2, cur_planes)))
if i == self.num_down-1:
cated_planes[i] = cur_planes * 2
elif has_trans:
cated_planes[i] = cur_planes + trans_planes
else:
cated_planes[i] = cur_planes
fish.append(sample_block)
return nn.ModuleList(fish)
def _fish_forward(self, all_feat):
def _concat(a, b):
return torch.cat([a, b], dim=1)
def stage_factory(*blks):
def stage_forward(*inputs):
if stg_id < self.num_down: # tail
tail_blk = nn.Sequential(*blks[:2])
# print(stg_id)
# print(tail_blk)
return tail_blk(*inputs)
elif stg_id == self.num_down:
score_blks = nn.Sequential(*blks[:2])
score_feat = score_blks(inputs[0])
att_feat = blks[3](score_feat)
return blks[2](score_feat) * att_feat + att_feat
else: # refine
feat_trunk = blks[2](blks[0](inputs[0]))
feat_branch = blks[1](inputs[1])
return _concat(feat_trunk, feat_branch)
return stage_forward
stg_id = 0
# tail:
while stg_id < self.depth:
stg_blk = stage_factory(*self.fish[stg_id])
if stg_id <= self.num_down:
in_feat = [all_feat[stg_id]]
else:
trans_id = self.trans_map[stg_id-self.num_down-1]
in_feat = [all_feat[stg_id], all_feat[trans_id]]
all_feat[stg_id + 1] = stg_blk(*in_feat)
stg_id += 1
# loop exit
if stg_id == self.depth:
score_feat = self.fish[self.depth-1][-2](all_feat[-1])
score = self.fish[self.depth-1][-1](score_feat)
for fea in all_feat:
print(fea.shape)
return score
def forward(self, x):
all_feat = [None] * (self.depth + 1)
all_feat[0] = x
return self._fish_forward(all_feat)
class FishNet(nn.Module):
def __init__(self, block, **kwargs):
super(FishNet, self).__init__()
inplanes = kwargs['network_planes'][0]
# resolution: 224x224
self.conv1 = self._conv_bn_relu(3, inplanes // 2, stride=2)
self.conv2 = self._conv_bn_relu(inplanes // 2, inplanes // 2)
self.conv3 = self._conv_bn_relu(inplanes // 2, inplanes)
self.pool1 = nn.MaxPool2d(3, padding=1, stride=2)
# construct fish, resolution 56x56
self.fish = Fish(block, **kwargs)
self._init_weights()
def _conv_bn_relu(self, in_ch, out_ch, stride=1):
return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride, bias=False),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True))
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.pool1(x)
# x.Size([1, 64, 56, 56])
score = self.fish(x)
# 1*1 output
out = score.view(x.size(0), -1)
return out
def fish(**kwargs):
return FishNet(Bottleneck, **kwargs)
2) fish_block.py:
import torch.nn as nn
class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, mode='NORM', k=1, dilation=1):
"""
Pre-act residual block, the middle transformations are bottle-necked
:param inplanes:
:param planes:
:param stride:
:param downsample:
:param mode: NORM | UP
:param k: times of additive
"""
super(Bottleneck, self).__init__()
self.mode = mode
self.relu = nn.ReLU(inplace=True)
self.k = k
btnk_ch = planes // 4
self.bn1 = nn.BatchNorm2d(inplanes)
self.conv1 = nn.Conv2d(inplanes, btnk_ch, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(btnk_ch)
self.conv2 = nn.Conv2d(btnk_ch, btnk_ch, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False)
self.bn3 = nn.BatchNorm2d(btnk_ch)
self.conv3 = nn.Conv2d(btnk_ch, planes, kernel_size=1, bias=False)
if mode == 'UP':
self.shortcut = None
elif inplanes != planes or stride > 1:
self.shortcut = nn.Sequential(
nn.BatchNorm2d(inplanes),
self.relu,
nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
)
else:
self.shortcut = None
def _pre_act_forward(self, x):
residual = x
out = self.bn1(x)
out = self.relu(out)
out = self.conv1(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn3(out)
out = self.relu(out)
out = self.conv3(out)
if self.mode == 'UP':
residual = self.squeeze_idt(x)
elif self.shortcut is not None:
residual = self.shortcut(residual)
out += residual
return out
def squeeze_idt(self, idt):
n, c, h, w = idt.size()
return idt.view(n, c // self.k, self.k, h, w).sum(2)
def forward(self, x):
out = self._pre_act_forward(x)
return out
3) fish_block.py:
from models.fishnet import fish
import torch
def fishnet99(**kwargs):
"""
:return:
"""
net_cfg = {
# input size: [224, 56, 28, 14 | 7, 7, 14, 28 | 56, 28, 14]
# output size: [56, 28, 14, 7 | 7, 14, 28, 56 | 28, 14, 7]
# | | | | | | | | | | |
'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
'num_res_blks': [2, 2, 6, 2, 1, 1, 1, 1, 2, 2],
'num_trans_blks': [1, 1, 1, 1, 1, 4],
'num_cls': 1000,
'num_down_sample': 3,
'num_up_sample': 3,
}
cfg = {**net_cfg, **kwargs}
return fish(**cfg)
def fishnet150(**kwargs):
"""
:return:
"""
net_cfg = {
# input size: [224, 56, 28, 14 | 7, 7, 14, 28 | 56, 28, 14]
# output size: [56, 28, 14, 7 | 7, 14, 28, 56 | 28, 14, 7]
# | | | | | | | | | | |
'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
'num_res_blks': [2, 4, 8, 4, 2, 2, 2, 2, 2, 4],
'num_trans_blks': [2, 2, 2, 2, 2, 4],
'num_cls': 1000,
'num_down_sample': 3,
'num_up_sample': 3,
}
cfg = {**net_cfg, **kwargs}
return fish(**cfg)
def fishnet201(**kwargs):
"""
:return:
"""
net_cfg = {
# input size: [224, 56, 28, 14 | 7, 7, 14, 28 | 56, 28, 14]
# output size: [56, 28, 14, 7 | 7, 14, 28, 56 | 28, 14, 7]
# | | | | | | | | | | |
'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600],
'num_res_blks': [3, 4, 12, 4, 2, 2, 2, 2, 3, 10],
'num_trans_blks': [2, 2, 2, 2, 2, 9],
'num_cls': 1000,
'num_down_sample': 3,
'num_up_sample': 3,
}
cfg = {**net_cfg, **kwargs}
return fish(**cfg)
四、总结
- 创造性的在类FCN的网络后再次添加了卷积神经网络。这样的处理使得用于目标检测、语义分割等任务的卷积神经网络可以用于图像分类。并且充分利用到了卷积神经网络所提取到的细节信息。
- 在网络中不再使用孤立卷积,使得深层的梯度可以直接传递到浅层。