网络主要搭建代码
三个结构函数
- _make_score:最后一层
- _make_se_block:SE 通道注意力模块,结构为 BN -> ReLU -> AdaptiveAvgPool2d(全局池化到 1×1)-> conv1x1 -> ReLU -> conv1x1 -> Sigmoid
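- _make_residual_block:中间的残差结构,由 nstage 个 block 串联而成;is_up=True 时第一个 block 以 mode='UP' 构建(用于上采样阶段),否则第一个 block 为普通 block(stride=1,用于下采样或不采样的阶段)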
def _make_score(self, in_ch, out_ch=10, has_pool=False):
bn = nn.BatchNorm2d(in_ch)
relu = nn.ReLU(inplace=True)
conv_trans = nn.Conv2d(in_ch, in_ch // 2, kernel_size=1, bias=False)
bn_out = nn.BatchNorm2d(in_ch // 2)
conv = nn.Sequential(bn, relu, conv_trans, bn_out, relu)
if has_pool:
fc = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True))
else:
fc = nn.Conv2d(in_ch // 2, out_ch, kernel_size=1, bias=True)
return [conv, fc]
def _make_se_block(self, in_ch, out_ch):
bn = nn.BatchNorm2d(in_ch)
sq_conv = nn.Conv2d(in_ch, out_ch // 16, kernel_size=1)
ex_conv = nn.Conv2d(out_ch // 16, out_ch, kernel_size=1)
print("se_block out ch", out_ch)
return nn.Sequential(bn,
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(1),
sq_conv,
nn.ReLU(inplace=True),
ex_conv,
nn.Sigmoid())
def _make_residual_block(self, inplanes, outplanes, nstage, is_up=False, k=1, dilation=1):
layers = []
if is_up:
layers.append(self.block(inplanes, outplanes, mode='UP', dilation=dilation, k=k))
else:
layers.append(self.block(inplanes, outplanes, stride=1))
for i in range(1, nstage):
layers.append(self.block(outplanes, outplanes, stride=1, dilation=dilation))
return nn.Sequential(*layers)
- _make_stage(self, is_down_sample, inplanes, outplanes, n_blk, has_trans=True,has_score=False, trans_planes=0, no_sampling=False, num_trans=2, **kwargs)
根据不同的输入参数,在不同阶段(stage i = 0, 1, …, depth-1)添加上面三种层中的一种或几种
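- 参数(以 num_down=3、num_up=3、depth=10 为例;t = True,f = False):

| 判断参数 \ stage | 0~2 | 3 | 4~6 | 7~9 |
| --- | --- | --- | --- | --- |
| is_down | t | f | f | t |
| has_trans | f | f | t | t |
| no_sample | f | t | f | f |
| 层 | res(down) | score(无池化)、res、se | res(up) | res(down)、score(仅 i = depth-1) |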
- 其中输入、输出通道数由net_factory参数、以及trans_map设定
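输入输出通道数 'network_planes': [64, 128, 256, 512, 512, 512, 384, 256, 320, 832, 1600];特征传递序号 trans_map = (2, 1, 0, 6, 5, 4),其值是 all_feat 的下标(从 0 开始,all_feat[0] 为网络输入,all_feat[k] 为第 k-1 个 stage 的输出)。对应关系(all_feat 下标 → 使用该特征的 stage):2→4,1→5,0→6,6→7,5→8,4→9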
源码:https://github.com/kevin-ssy/FishNet/blob/master/models/fishnet.py
'''
FishNet
Author: Shuyang Sun
'''
from __future__ import division
import torch
import math
from .fish_block import *
__all__ = ['fish']
class Fish(nn.Module):
    """Stack of FishNet stages followed by a pooled scoring head.

    The network is a list of ``depth = len(network_planes) - 1`` stages:

    * stages ``0 .. num_down - 1`` ("tail"): residual blocks + max-pool;
    * stage ``num_down``: score block (no pooling) + residual blocks,
      re-weighted by a channel-attention branch, no resampling;
    * later stages ("refine"): residual blocks on the trunk plus a transfer
      branch fed from an earlier feature selected through ``trans_map``; the
      trunk is up-sampled for stages up to ``num_down + num_up`` and
      max-pooled afterwards, then concatenated with the branch.

    Args:
        block: Callable creating one residual block. It is called as
            ``block(inplanes, outplanes, stride=1[, dilation=...])`` or
            ``block(inplanes, outplanes, mode='UP', dilation=..., k=...)``.
        num_cls: Number of output channels of the final score.
        num_down_sample: Number of leading down-sampling stages.
        num_up_sample: Number of up-sampling stages.
        trans_map: For each refine stage, the index into the feature list
            (0 is the network input, ``j + 1`` is the output of stage ``j``)
            that feeds its transfer branch.
        network_planes: Channel counts; element 0 is the input channel count,
            the rest are the per-stage output channels. Required.
        num_res_blks: Number of residual blocks per stage. Required.
        num_trans_blks: Number of transfer blocks per refine stage. Required.

    Raises:
        TypeError: If one of the required configuration lists is missing.
    """

    def __init__(self, block, num_cls=10, num_down_sample=5, num_up_sample=3, trans_map=(2, 1, 0, 6, 5, 4),
                 network_planes=None, num_res_blks=None, num_trans_blks=None):
        super(Fish, self).__init__()
        if network_planes is None or num_res_blks is None or num_trans_blks is None:
            # Same exception type the original indexing of None produced,
            # but raised up-front with an actionable message.
            raise TypeError(
                "Fish requires network_planes, num_res_blks and num_trans_blks")
        self.block = block
        self.trans_map = trans_map
        self.upsample = nn.Upsample(scale_factor=2)
        self.down_sample = nn.MaxPool2d(2, stride=2)
        self.num_cls = num_cls
        self.num_down = num_down_sample
        self.num_up = num_up_sample
        # Fallback channel count used by _make_stage when trans_planes == 0.
        # It was read there but never assigned, which raised AttributeError.
        self.in_planes = network_planes[0]
        self.network_planes = network_planes[1:]
        self.depth = len(self.network_planes)
        self.num_trans_blks = num_trans_blks
        self.num_res_blks = num_res_blks
        self.fish = self._make_fish(network_planes[0])

    def _make_score(self, in_ch, out_ch=10, has_pool=False):
        """Build the two modules of a scoring head.

        Returns:
            ``[conv, fc]`` where ``conv`` is BN -> ReLU -> 1x1 conv
            (``in_ch`` to ``in_ch // 2``, no bias) -> BN -> ReLU and ``fc`` is
            a 1x1 conv to ``out_ch``, preceded by ``AdaptiveAvgPool2d(1)``
            when ``has_pool`` is true.
        """
        mid_ch = in_ch // 2
        # A single ReLU module is deliberately used in both activation slots.
        act = nn.ReLU(inplace=True)
        squeeze = nn.Sequential(
            nn.BatchNorm2d(in_ch),
            act,
            nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            act,
        )
        classifier = nn.Conv2d(mid_ch, out_ch, kernel_size=1, bias=True)
        if has_pool:
            classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), classifier)
        return [squeeze, classifier]

    def _make_se_block(self, in_ch, out_ch):
        """Build the channel-attention branch.

        BN -> ReLU -> global average pool (1x1) -> 1x1 conv to
        ``out_ch // 16`` -> ReLU -> 1x1 conv to ``out_ch`` -> Sigmoid.
        """
        hidden_ch = out_ch // 16
        bn = nn.BatchNorm2d(in_ch)
        sq_conv = nn.Conv2d(in_ch, hidden_ch, kernel_size=1)
        ex_conv = nn.Conv2d(hidden_ch, out_ch, kernel_size=1)
        print("se_block out ch", out_ch)
        return nn.Sequential(bn,
                             nn.ReLU(inplace=True),
                             nn.AdaptiveAvgPool2d(1),
                             sq_conv,
                             nn.ReLU(inplace=True),
                             ex_conv,
                             nn.Sigmoid())

    def _make_residual_block(self, inplanes, outplanes, nstage, is_up=False, k=1, dilation=1):
        """Chain ``nstage`` blocks created by ``self.block``.

        The first block maps ``inplanes`` to ``outplanes`` (created with
        ``mode='UP'``, ``k`` and ``dilation`` when ``is_up`` is true, with
        ``stride=1`` otherwise); the remaining blocks keep ``outplanes``.
        """
        if is_up:
            layers = [self.block(inplanes, outplanes, mode='UP', dilation=dilation, k=k)]
        else:
            layers = [self.block(inplanes, outplanes, stride=1)]
        for _ in range(1, nstage):
            layers.append(self.block(outplanes, outplanes, stride=1, dilation=dilation))
        return nn.Sequential(*layers)

    def _make_stage(self, is_down_sample, inplanes, outplanes, n_blk, has_trans=True,
                    has_score=False, trans_planes=0, no_sampling=False, num_trans=2, **kwargs):
        """Assemble the modules of one stage.

        The returned ``nn.ModuleList`` holds, in this order:
        the two score modules (only if ``has_score``), the residual block,
        the transfer block (only if ``has_trans``), and the resampling layer
        (max-pool if ``is_down_sample``, up-sample otherwise; omitted when
        ``no_sampling``). ``_fish_forward`` relies on these positions.

        Args:
            is_down_sample: Whether the stage down-samples.
            inplanes: Input channels of the residual block.
            outplanes: Output channels of the residual block.
            n_blk: Number of blocks in the residual block.
            has_trans: Whether to add a transfer block.
            has_score: Whether to prepend a score block.
            trans_planes: Channels of the transfer block; ``0`` falls back to
                ``self.in_planes``.
            no_sampling: Whether to skip the resampling layer.
            num_trans: Number of blocks in the transfer block.
            **kwargs: Forwarded to ``_make_residual_block`` (``k``, ``dilation``).
        """
        sample_block = []
        if has_score:
            sample_block.extend(self._make_score(outplanes, outplanes * 2, has_pool=False))

        if no_sampling or is_down_sample:
            res_block = self._make_residual_block(inplanes, outplanes, n_blk, **kwargs)
        else:
            res_block = self._make_residual_block(inplanes, outplanes, n_blk, is_up=True, **kwargs)
        sample_block.append(res_block)

        if has_trans:
            trans_in_planes = self.in_planes if trans_planes == 0 else trans_planes
            sample_block.append(self._make_residual_block(trans_in_planes, trans_in_planes, num_trans))

        if not no_sampling:
            sample_block.append(self.down_sample if is_down_sample else self.upsample)
        return nn.ModuleList(sample_block)

    def _make_fish(self, in_planes):
        """Build every stage and return them as an ``nn.ModuleList``.

        ``cated_planes[i]`` tracks the channel count recorded for stage ``i``
        (after concatenation for refine stages); it is what later stages read
        as their input and transfer channel counts.
        """
        cated_planes = [in_planes] * self.depth
        fish = []
        non_down_stages = range(self.num_down, self.num_down + self.num_up + 1)
        for i in range(self.depth):
            # Stages before num_down and after num_down + num_up down-sample,
            # stage num_down does not resample, the ones in between up-sample.
            is_down = i not in non_down_stages
            has_trans = i > self.num_down
            no_sampling = i == self.num_down

            # Index into the per-refine-stage configs. It is negative for
            # i <= num_down and then wraps around; the values looked up for
            # those stages are not used to build a transfer branch.
            cfg_idx = i - self.num_down - 1
            cur_planes = self.network_planes[i]
            map_id = self.trans_map[cfg_idx] - 1
            trans_planes = in_planes if map_id == -1 else cated_planes[map_id]
            cur_blocks = self.num_res_blks[i]
            num_trans = self.num_trans_blks[cfg_idx]

            # For i == 0 this reads cated_planes[-1], which still holds in_planes.
            prev_planes = cated_planes[i - 1]
            if is_down or no_sampling:
                k, dilation = 1, 1
            else:
                k, dilation = prev_planes // cur_planes, 2 ** (i - self.num_down - 1)

            sample_block = self._make_stage(is_down, prev_planes, cur_planes, cur_blocks,
                                            has_trans=has_trans, trans_planes=trans_planes,
                                            has_score=(i == self.num_down), num_trans=num_trans,
                                            k=k, dilation=dilation, no_sampling=no_sampling)
            if i == self.depth - 1:
                # Final classifier head, appended as the last two modules.
                sample_block.extend(
                    self._make_score(cur_planes + trans_planes, out_ch=self.num_cls, has_pool=True))
            elif i == self.num_down:
                # Kept wrapped in an extra Sequential so parameter names stay unchanged.
                sample_block.append(nn.Sequential(self._make_se_block(cur_planes * 2, cur_planes)))

            if i == self.num_down - 1:
                cated_planes[i] = cur_planes * 2
            elif has_trans:
                cated_planes[i] = cur_planes + trans_planes
            else:
                cated_planes[i] = cur_planes
            fish.append(sample_block)
        return nn.ModuleList(fish)

    def _fish_forward(self, all_feat):
        """Run all stages and return the final score.

        Args:
            all_feat: List whose element 0 is the input feature; element
                ``i + 1`` is filled in place with the output of stage ``i``.

        Returns:
            The output of the last stage passed through the two score modules
            stored at the end of that stage.
        """
        for stg_id in range(self.depth):
            blks = self.fish[stg_id]
            trunk_in = all_feat[stg_id]
            if stg_id < self.num_down:
                # tail: residual block, then down-sample
                out = blks[1](blks[0](trunk_in))
            elif stg_id == self.num_down:
                # score block -> residual block, re-weighted by the attention branch
                score_feat = blks[1](blks[0](trunk_in))
                att_feat = blks[3](score_feat)
                out = blks[2](score_feat) * att_feat + att_feat
            else:
                # refine: resampled trunk concatenated with the transfer branch
                trans_id = self.trans_map[stg_id - self.num_down - 1]
                feat_trunk = blks[2](blks[0](trunk_in))
                feat_branch = blks[1](all_feat[trans_id])
                out = torch.cat([feat_trunk, feat_branch], dim=1)
            all_feat[stg_id + 1] = out

        last_stage = self.fish[self.depth - 1]
        score_feat = last_stage[-2](all_feat[-1])
        return last_stage[-1](score_feat)

    def forward(self, x):
        """Compute the score for input feature ``x``."""
        all_feat = [None] * (self.depth + 1)
        all_feat[0] = x
        return self._fish_forward(all_feat)
class FishNet(nn.Module):
    """Three-convolution stem plus max-pool, followed by a ``Fish`` network.

    Args:
        block: Residual block factory forwarded to ``Fish``.
        **kwargs: Configuration forwarded to ``Fish``; must contain
            ``network_planes``, whose first element sets the stem's output
            channel count.
    """

    def __init__(self, block, **kwargs):
        super(FishNet, self).__init__()
        print(kwargs)
        stem_out = kwargs['network_planes'][0]
        stem_mid = stem_out // 2
        # Stem: the first convolution has stride 2 and the pooling layer
        # has stride 2 as well.
        self.conv1 = self._conv_bn_relu(3, stem_mid, stride=2)
        self.conv2 = self._conv_bn_relu(stem_mid, stem_mid)
        self.conv3 = self._conv_bn_relu(stem_mid, stem_out)
        self.pool1 = nn.MaxPool2d(3, padding=1, stride=2)
        self.fish = Fish(block, **kwargs)
        self._init_weights()

    def _conv_bn_relu(self, in_ch, out_ch, stride=1):
        """Return a 3x3 conv (padding 1, no bias) -> BatchNorm -> ReLU stack."""
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, stride=stride, bias=False)
        layers = [conv, nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)]
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Initialise every Conv2d weight and BatchNorm2d parameter in place.

        Conv weights are drawn from N(0, sqrt(2 / n)) with
        n = kernel_height * kernel_width * out_channels; BatchNorm weights are
        set to 1 and biases to 0. Conv biases are left untouched.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size[0], module.kernel_size[1]
                n = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        """Return the score flattened to ``(batch, -1)``."""
        stem = self.pool1(self.conv3(self.conv2(self.conv1(x))))
        score = self.fish(stem)
        # Flatten everything but the batch dimension.
        return score.view(stem.size(0), -1)
def fish(**kwargs):
    """Create a ``FishNet`` that uses ``Bottleneck`` as its residual block.

    All keyword arguments are forwarded unchanged to ``FishNet``.
    """
    model = FishNet(Bottleneck, **kwargs)
    return model