Environment setup
Install the required packages following the instructions at https://github.com/HRNet/HRNet-Semantic-Segmentation.
The following two environments have been tested and work:
(1) Ubuntu 16.04, Python 3, PyTorch 0.4.1, CUDA 9
(2) Windows 10, Python 3, PyTorch 0.4.1, CUDA 9
Note: the import from .sync_bn.inplace_abn.bn import InPlaceABNSync must be commented out (InPlaceABNSync is a layer that shares BN statistics across multiple GPUs, but I dropped it to get the network running quickly). Since this BN variant is not used, seg_hrnet.py needs the following change:
# BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
BatchNorm2d = nn.BatchNorm2d
as well as commenting out, around line 464:
# elif isinstance(m, InPlaceABNSync):
#     nn.init.constant_(m.weight, 1)
#     nn.init.constant_(m.bias, 0)
Downloading the data
Link: https://pan.baidu.com/s/1dxsVOOZ1RC7c-obM23fHIg
Extraction code: kmrl
After downloading, place the gtFine and leftImg8bit folders into the data directory of the source tree. The resulting structure under data looks like this:
$SEG_ROOT/data
├── cityscapes
│ ├── gtFine
│ │ ├── test
│ │ ├── train
│ │ └── val
│ └── leftImg8bit
│ ├── test
│ ├── train
│ └── val
├── list
│ ├── cityscapes
│ │ ├── test.lst
│ │ ├── trainval.lst
│ │ └── val.lst
│ ├── lip
│ │ ├── testvalList.txt
│ │ ├── trainList.txt
│ │ └── valList.txt
Note: remember to delete the following line from val.lst, because this image is problematic:
leftImg8bit/val/frankfurt/frankfurt_000001_059119_leftImg8bit.png gtFine/val/frankfurt/frankfurt_000001_059119_gtFine_labelIds.png
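If you prefer to remove it programmatically, here is a small sketch; the helper name is mine, and the path you pass in should be wherever val.lst sits in your checkout (e.g. data/list/cityscapes/val.lst):

```python
# Drop every line mentioning the broken Frankfurt image pair from a .lst file.
def drop_bad_pairs(lst_path, token='frankfurt_000001_059119'):
    with open(lst_path) as f:
        kept = [line for line in f if token not in line]
    with open(lst_path, 'w') as f:
        f.writelines(kept)
    return len(kept)  # number of lines remaining
```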
Running the training script
python tools/train.py --cfg experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml
Remember to edit seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml first; the lines marked with # are the ones I changed. I have 2 GPUs; for why WORKERS is 0, see https://blog.csdn.net/u013066730/article/details/97808471.
CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
GPUS: (0,1) #########################
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 0 ######################
PRINT_FREQ: 100
DATASET:
  DATASET: cityscapes
  ROOT: 'data/'
  TEST_SET: 'list/cityscapes/val.lst'
  TRAIN_SET: 'list/cityscapes/train.lst'
  NUM_CLASSES: 19
MODEL:
  NAME: seg_hrnet
  PRETRAINED: 'pretrained_models/hrnetv2_w48_imagenet_pretrained.pth'
  EXTRA:
    FINAL_CONV_KERNEL: 1
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      - 192
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      - 192
      - 384
      FUSE_METHOD: SUM
LOSS:
  USE_OHEM: false
  OHEMTHRES: 0.9
  OHEMKEEP: 131072
TRAIN:
  IMAGE_SIZE:
  - 1024
  - 512
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 2 ############################
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 484
  RESUME: true
  OPTIMIZER: sgd
  LR: 0.01
  WD: 0.0005
  MOMENTUM: 0.9
  NESTEROV: false
  FLIP: true
  MULTI_SCALE: true
  DOWNSAMPLERATE: 1
  IGNORE_LABEL: 255
  SCALE_FACTOR: 16
TEST:
  IMAGE_SIZE:
  - 2048
  - 1024
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 3 ###############################
  FLIP_TEST: false
  MULTI_SCALE: false
Walkthrough of the training code (I only cover the key parts; follow along in the source)
Parsing the cfg parameters
Start in tools/train.py, which parses the parameters from the yaml file using the yacs library. If you don't know how yacs works, see https://blog.csdn.net/u013066730/article/details/97640131.
from config import config
from config import update_config

def parse_args():
    parser = argparse.ArgumentParser(description='Train segmentation network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()
    update_config(config, args)

    return args

def main():
    args = parse_args()
From the imports we can see that config comes straight from the lib/config package, i.e. that package's __init__ file is what gets executed. Looking at lib/config/__init__, it imports some parameters and functions from .default; as far as my testing shows, the .models import is unused here. For how config and update_config actually update the parameters, see https://blog.csdn.net/u013066730/article/details/97640131.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .default import _C as config
from .default import update_config
from .models import MODEL_EXTRAS
Logger setup
logger, final_output_dir, tb_log_dir = create_logger(
    config, args.cfg, 'train')

logger.info(pprint.pformat(args))
logger.info(config)

writer_dict = {
    'writer': SummaryWriter(tb_log_dir),
    'train_global_steps': 0,
    'valid_global_steps': 0,
}
I won't go into the logger details here; please read the source yourself.
cuDNN-related settings
# cudnn related setting
cudnn.benchmark = config.CUDNN.BENCHMARK
cudnn.deterministic = config.CUDNN.DETERMINISTIC
cudnn.enabled = config.CUDNN.ENABLED
gpus = list(config.GPUS)
Building the model (a detailed walkthrough follows later, so just a quick pass here)
# build model
model = eval('models.'+config.MODEL.NAME +
             '.get_seg_model')(config)
Dumping a model summary
dump_input = torch.rand(
    (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
)
logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
I'm not sure why the next step is done (a backup, presumably):
# copy model file
this_dir = os.path.dirname(__file__)
models_dst_dir = os.path.join(final_output_dir, 'models')
if os.path.exists(models_dst_dir):
    shutil.rmtree(models_dst_dir)
shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)
The files get copied to HRNet-Semantic-Segmentation-master\output\cityscapes\seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484\models.
Loading the data
import datasets

train_dataset = eval('datasets.'+config.DATASET.DATASET)(
    root=config.DATASET.ROOT,
    list_path=config.DATASET.TRAIN_SET,
    num_samples=None,
    num_classes=config.DATASET.NUM_CLASSES,
    multi_scale=config.TRAIN.MULTI_SCALE,
    flip=config.TRAIN.FLIP,
    ignore_label=config.TRAIN.IGNORE_LABEL,
    base_size=config.TRAIN.BASE_SIZE,
    crop_size=crop_size,
    downsample_rate=config.TRAIN.DOWNSAMPLERATE,
    scale_factor=config.TRAIN.SCALE_FACTOR)

trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU*len(gpus),
    shuffle=config.TRAIN.SHUFFLE,
    num_workers=config.WORKERS,
    pin_memory=True,
    drop_last=True)
The import pulls in the datasets package, i.e. datasets/__init__.py, so it is worth looking at its contents:
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .cityscapes import Cityscapes as cityscapes
from .lip import LIP as lip
from .pascal_ctx import PASCALContext as pascal_ctx
Since we are using the Cityscapes dataset, we only need to look at the cityscapes module. One thing worth spelling out: after substituting the config value, eval('datasets.'+config.DATASET.DATASET) becomes eval('datasets.cityscapes'). This treats the string 'datasets.cityscapes' as executable code, i.e. it resolves to the callable datasets.cityscapes, which is then invoked with the keyword arguments shown above.
Then look at the Cityscapes class in datasets\cityscapes.py; it is quite clear. I won't expand on the implementation here, since it is basically the usual hand-written data-loading class with a few extras. The main thing to note is the final output, shown below.
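To see what this eval-based dispatch does without the real package, here is a minimal stand-in; the dummy datasets namespace below is mine, not the repo's:

```python
import types

# A stand-in for the datasets package: one attribute per dataset class.
datasets = types.SimpleNamespace(
    cityscapes=lambda **kwargs: ('Cityscapes', kwargs))

name = 'cityscapes'                     # plays the role of config.DATASET.DATASET
dataset_cls = eval('datasets.' + name)  # resolves the string to datasets.cityscapes
tag, kwargs = dataset_cls(root='data/', num_classes=19)
print(tag, kwargs['num_classes'])       # Cityscapes 19
```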
return image.copy(), label.copy(), np.array(size), name
At this point image has shape (3, 512, 1024) and label has shape (512, 1024).
Loss function
# criterion
if config.LOSS.USE_OHEM:
    criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 thres=config.LOSS.OHEMTHRES,
                                 min_kept=config.LOSS.OHEMKEEP,
                                 weight=train_dataset.class_weights)
else:
    criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                             weight=train_dataset.class_weights)
Since OHEM is disabled in the config, the ordinary class-weighted cross-entropy loss is used. The CrossEntropy class lives in lib\core\criterion.py:
class CrossEntropy(nn.Module):
    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight,
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        # target shape is [2, 512, 1024], score shape is [2, 19, 128, 256]
        ph, pw = score.size(2), score.size(3)
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            score = F.upsample(
                input=score, size=(h, w), mode='bilinear')

        loss = self.criterion(score, target)

        return loss
The shapes of target and score are given in the comments: score must be upsampled to the label resolution before the loss is computed, and in this setup both its height and width grow by a factor of 4.
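A quick shape-only check of that factor-of-4 resize (this assumes PyTorch is installed; F.upsample in the snippet above is the older alias of F.interpolate):

```python
import torch
import torch.nn.functional as F

score = torch.randn(2, 19, 128, 256)           # network output
target = torch.randint(0, 19, (2, 512, 1024))  # label map

# Resize score to the label's spatial size, as the loss does.
score = F.interpolate(score, size=(target.size(1), target.size(2)),
                      mode='bilinear', align_corners=False)
print(score.shape)  # torch.Size([2, 19, 512, 1024])
```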
Wrapping the model and the loss into one module
model = FullModel(model, criterion)
model = nn.DataParallel(model, device_ids=gpus).cuda()
Optimizer
# optimizer
if config.TRAIN.OPTIMIZER == 'sgd':
    optimizer = torch.optim.SGD([{'params':
                                  filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  'lr': config.TRAIN.LR}],
                                lr=config.TRAIN.LR,
                                momentum=config.TRAIN.MOMENTUM,
                                weight_decay=config.TRAIN.WD,
                                nesterov=config.TRAIN.NESTEROV,
                                )
else:
    raise ValueError('Only Support SGD optimizer')
Iteration bookkeeping and checkpoint resuming
epoch_iters = np.int(train_dataset.__len__() /
                     config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
best_mIoU = 0
last_epoch = 0
if config.TRAIN.RESUME:
    model_state_file = os.path.join(final_output_dir,
                                    'checkpoint.pth.tar')
    if os.path.isfile(model_state_file):
        checkpoint = torch.load(model_state_file)
        best_mIoU = checkpoint['best_mIoU']
        last_epoch = checkpoint['epoch']
        model.module.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint (epoch {})"
                    .format(checkpoint['epoch']))

start = timeit.default_timer()
end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
num_iters = config.TRAIN.END_EPOCH * epoch_iters
extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters
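As a sanity check on these counts, assuming the standard Cityscapes train split of 2975 images and the 2-GPU, 2-images-per-GPU config above:

```python
# epoch_iters truncates, exactly as np.int(...) does in the source.
num_images, bs_per_gpu, num_gpus = 2975, 2, 2
epoch_iters = int(num_images / bs_per_gpu / num_gpus)
num_iters = 484 * epoch_iters   # END_EPOCH * epoch_iters
print(epoch_iters, num_iters)   # 743 359612
```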
Training and validation loop
for epoch in range(last_epoch, end_epoch):
    if epoch >= config.TRAIN.END_EPOCH:
        train(config, epoch-config.TRAIN.END_EPOCH,
              config.TRAIN.EXTRA_EPOCH, epoch_iters,
              config.TRAIN.EXTRA_LR, extra_iters,
              extra_trainloader, optimizer, model, writer_dict)
    else:
        train(config, epoch, config.TRAIN.END_EPOCH,
              epoch_iters, config.TRAIN.LR, num_iters,
              trainloader, optimizer, model, writer_dict)

    logger.info('=> saving checkpoint to {}'.format(
        final_output_dir + 'checkpoint.pth.tar'))
    torch.save({
        'epoch': epoch+1,
        'best_mIoU': best_mIoU,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, os.path.join(final_output_dir, 'checkpoint.pth.tar'))

    valid_loss, mean_IoU, IoU_array = validate(
        config, testloader, model, writer_dict)

    if mean_IoU > best_mIoU:
        best_mIoU = mean_IoU
        torch.save(model.module.state_dict(),
                   os.path.join(final_output_dir, 'best.pth'))
    msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
        valid_loss, mean_IoU, best_mIoU)
    logging.info(msg)
    logging.info(IoU_array)
There is not much to explain here; the train and validate functions live in core/function.py. Let's look at the train function.
def train(config, epoch, num_epoch, epoch_iters, base_lr,
          num_iters, trainloader, optimizer, model, writer_dict):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch*epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']

    for i_iter, batch in enumerate(trainloader, 0):
        # image shape is [4,3,512,1024] and label shape is [4,512,1024]
        images, labels, _, _ = batch
        # print(images.size())
        # print(labels.size())
        labels = labels.long().cuda()

        losses, _ = model(images, labels)
        loss = losses.mean()

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(loss.item())

        lr = adjust_learning_rate(optimizer,
                                  base_lr,
                                  num_iters,
                                  i_iter+cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters,
                      batch_time.average(), lr, ave_loss.average())
            logging.info(msg)

    writer.add_scalar('train_loss', ave_loss.average(), global_steps)
    writer_dict['train_global_steps'] = global_steps + 1
Inside train, the batch has image shape [4,3,512,1024] and label shape [4,512,1024]: the batch size per GPU is 2, but with 2 GPUs the loader delivers batches of 4.
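The adjust_learning_rate call above implements a "poly" decay; here is a minimal standalone sketch (the exponent power=0.9 is what I recall from lib/core/function.py, so treat it as an assumption):

```python
def poly_lr(base_lr, max_iters, cur_iter, power=0.9):
    # lr shrinks from base_lr toward 0 as cur_iter approaches max_iters
    return base_lr * (1 - cur_iter / max_iters) ** power

base_lr = 0.01                       # TRAIN.LR from the yaml
print(poly_lr(base_lr, 1000, 0))     # 0.01 at the very first iteration
print(poly_lr(base_lr, 1000, 1000))  # 0.0 at the last iteration
```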
Detailed walkthrough of the network architecture
Start in models\seg_hrnet.py with:
def get_seg_model(cfg, **kwargs):
    model = HighResolutionNet(cfg, **kwargs)
    model.init_weights(cfg.MODEL.PRETRAINED)

    return model
Two things happen here: HighResolutionNet builds the network, and init_weights initializes its weights. The focus from now on is HighResolutionNet.
For what follows, keep the source open; I only quote the parts being discussed.
These are the methods of HighResolutionNet (signatures only):
class HighResolutionNet(nn.Module):

    def __init__(self, config, **kwargs):
        extra = config.MODEL.EXTRA
        super(HighResolutionNet, self).__init__()

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True)

    def forward(self, x):  # x has shape (2,3,512,1024)

    def init_weights(self, pretrained='',)
Start with forward; the input x has shape (2,3,512,1024). conv1 and conv2 each have kernel size 3 and stride 2, so together they shrink the image: height and width both become 1/4 of the original, giving shape (2,64,128,256). Then come four residual units, self.layer1, after which the shape is (2,256,128,256).
def forward(self, x):  # x has shape (2,3,512,1024)
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.layer1(x)  # x now has shape (2,256,128,256)
The exact steps are shown in the figure below.
Continuing with the code: this part adds the branches and applies convolutions within each branch.
x_list = []
for i in range(self.stage2_cfg['NUM_BRANCHES']):
    if self.transition1[i] is not None:
        x_list.append(self.transition1[i](x))
    else:
        x_list.append(x)
From self.transition1, jump to self.transition1 = self._make_transition_layer([256], num_channels), and from there into _make_transition_layer:
i | num_branches_pre | operation |
0 | 1 | conv(256,48) k=3,s=1 |
1 | 1 | conv(256,96) k=3,s=2 |
def _make_transition_layer(
        self, num_channels_pre_layer, num_channels_cur_layer):
    num_branches_cur = len(num_channels_cur_layer)
    num_branches_pre = len(num_channels_pre_layer)
    # num_branches_cur is 2, num_branches_pre is 1
    # i=0: num_channels_cur_layer[i] != num_channels_pre_layer[i],
    #      and since the channel counts differ, that branch is taken.
    # i=1: the loop for j in range(i+1-num_branches_pre) runs and
    #      produces an extra, smaller feature map.

    transition_layers = []
    for i in range(num_branches_cur):
        if i < num_branches_pre:
            if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                transition_layers.append(nn.Sequential(
                    nn.Conv2d(num_channels_pre_layer[i],
                              num_channels_cur_layer[i],
                              3,
                              1,
                              1,
                              bias=False),
                    BatchNorm2d(
                        num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=False)))
            else:
                transition_layers.append(None)
        else:
            conv3x3s = []
            for j in range(i+1-num_branches_pre):
                inchannels = num_channels_pre_layer[-1]
                outchannels = num_channels_cur_layer[i] \
                    if j == i-num_branches_pre else inchannels
                conv3x3s.append(nn.Sequential(
                    nn.Conv2d(
                        inchannels, outchannels, 3, 2, 1, bias=False),
                    BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                    nn.ReLU(inplace=False)))
            transition_layers.append(nn.Sequential(*conv3x3s))

    return nn.ModuleList(transition_layers)
So x goes through these two convolutions and becomes x_list, which holds 2 branches of shapes (2,48,128,256) and (2,96,64,128).
Continuing down, we reach:
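Those branch shapes follow from the usual convolution output-size formula; a quick check for the k=3, p=1 convs used in _make_transition_layer:

```python
# out = floor((in + 2*pad - kernel) / stride) + 1
def conv_out(size, stride, k=3, p=1):
    return (size + 2 * p - k) // stride + 1

h, w = 128, 256  # spatial size after layer1
print(conv_out(h, 1), conv_out(w, 1))  # branch 0 (s=1): 128 256
print(conv_out(h, 2), conv_out(w, 2))  # branch 1 (s=2): 64 128
```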
y_list = self.stage2(x_list)  # fuse the 2 branches
Next comes _make_stage, which mainly wires up the parameters and decides how many of the big modules to stack. Let's look at the HighResolutionModule class that _make_stage instantiates.
class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
                self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(num_channels[branch_index] * block.expansion,
                            momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        # j indexes the input branch, i the output branch
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM),
                                nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear')
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
Again start from forward: first each of the 2 branches runs its own convolutions via _make_branches; in fact each branch applies 4 residual units. This is illustrated in the figure below.
After that comes the fusion, built by _make_fuse_layers, where j indexes the input branch and i the output branch.
i | j | operation |
0 | 0 | None |
0 | 1 | conv(96,48) s=1,k=1 |
1 | 0 | conv(48,96) s=2,k=3 |
1 | 1 | None |
Note that fuse_layers itself performs no rescaling; the scale change for j > i happens in forward, via upsampling.
Back in HighResolutionModule.forward, the fusion is performed, i.e. the sum. The code is:
x_fuse = []
for i in range(len(self.fuse_layers)):
    y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
    for j in range(1, self.num_branches):
        if i == j:
            y = y + x[j]
        elif j > i:
            width_output = x[i].shape[-1]
            height_output = x[i].shape[-2]
            y = y + F.interpolate(
                self.fuse_layers[i][j](x[j]),
                size=[height_output, width_output],
                mode='bilinear')
        else:
            y = y + self.fuse_layers[i][j](x[j])
    x_fuse.append(self.relu(y))
Combining the table with this code, you can see that j is the input branch, i is the output branch, and several j's are fused into each i. The operation is illustrated in the figure.
Back in HighResolutionNet.forward, y_list ends up with shapes (2,48,128,256) and (2,96,64,128).
Next comes transition2:
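The size bookkeeping in that fusion can be checked numerically; a sketch under the stage-2 shapes above (function name and structure are mine): paths with j > i upsample by 2**(j-i) in forward, paths with j < i apply (i-j) stride-2 3x3 convs.

```python
def fused_size(src, i, j):
    # src: (h, w) of input branch j; returns the size after mapping onto branch i
    h, w = src
    if j > i:                      # 1x1 conv, then upsample in forward
        f = 2 ** (j - i)
        return h * f, w * f
    for _ in range(i - j):         # stride-2 3x3 conv, padding 1
        h, w = (h + 2 - 3) // 2 + 1, (w + 2 - 3) // 2 + 1
    return h, w

sizes = [(128, 256), (64, 128)]    # stage-2 branch 0 and branch 1
print(fused_size(sizes[1], 0, 1))  # into branch 0: (128, 256)
print(fused_size(sizes[0], 1, 0))  # into branch 1: (64, 128)
```

Both paths land exactly on the output branch's resolution, which is why the sum in forward is well defined.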
x_list = []
for i in range(self.stage3_cfg['NUM_BRANCHES']):
    if self.transition2[i] is not None:
        if i < self.stage2_cfg['NUM_BRANCHES']:
            x_list.append(self.transition2[i](y_list[i]))
        else:
            x_list.append(self.transition2[i](y_list[-1]))
    else:
        x_list.append(y_list[i])
This effectively turns the 2 branches into 3. The resulting operations are:
i(0,1,2) | num_branches_pre | operation |
0 | 2 | None |
1 | 2 | None |
2 | 2 | conv(96,192),k=3,s=2 |
Now return to HighResolutionNet.forward; the figure below illustrates the operation. In the pink box, the yellow and orange blocks show that self.transition2 simply copies those two feature maps with no extra operation, while the small pink block is produced from the small orange block via conv(96,192), k=3, s=2; this differs slightly from the original paper.
Then comes the stage's within-branch convolution and fusion, which I won't detail since the flow is the same as above.
y_list = self.stage3(x_list)
A few points here deserve attention, which I'll explain against the figure below.
The blue box covers everything stage3 does, and the operations inside the blue box are repeated 4 times in total.
Each repetition consists of 2 steps, one _make_branches and one _make_fuse_layers; I won't enumerate every case, as you can plug in the numbers and verify them yourself.
Next come transition3 and stage4, which I'll cover together since they mirror the earlier stages; again refer to the figure.
The black box is transition3. It is applied to the result of the last of the 4 repetitions: the small black block is produced from the last purple block via a conv(192,384) k=3, s=2 convolution.
Finally the four branches are combined: the smaller feature maps are upsampled to the largest resolution and then concatenated.
# Upsampling
x0_h, x0_w = x[0].size(2), x[0].size(3)
x1 = F.upsample(x[1], size=(x0_h, x0_w), mode='bilinear')
x2 = F.upsample(x[2], size=(x0_h, x0_w), mode='bilinear')
x3 = F.upsample(x[3], size=(x0_h, x0_w), mode='bilinear')

x = torch.cat([x[0], x1, x2, x3], 1)  # shape (1,15C,h,w), here (2,720,128,256)
x = self.last_layer(x)  # here (2,19,128,256)
That concludes the network architecture.
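The channel count in that concat comment checks out: with C = 48 for HRNet-W48, the four branches carry C, 2C, 4C and 8C channels.

```python
C = 48
branch_channels = [C, 2 * C, 4 * C, 8 * C]  # 48, 96, 192, 384
print(sum(branch_channels))  # 720, i.e. the 15C in the comment above
```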