[HRNet] High-Resolution Representations for Labeling Pixels and Regions: Paper and Code Walkthrough

Contents

Environment Setup

Downloading the Data

Running the Training Script

Training Code Walkthrough (only the key parts; please follow along with the source)

Parsing the cfg Arguments

Logger Setup

CUDA Setup

Building the Model (detailed later, so only touched on briefly here)

Logging the Model Summary (TensorBoard)

Copying the Model Files (a backup, presumably?)

Loading the Data

Loss Function

Wrapping the Model with the Loss

Optimizer

Iteration Counting and Checkpoint Resuming

Training and Validation

Detailed Walkthrough of the Network Architecture


Environment Setup

Install the required packages as specified at https://github.com/HRNet/HRNet-Semantic-Segmentation.

The following two environments have been tested and work:

(1) Ubuntu 16.04, Python 3, PyTorch 0.4.1, CUDA 9

(2) Windows 10, Python 3, PyTorch 0.4.1, CUDA 9

Note: the import line from .sync_bn.inplace_abn.bn import InPlaceABNSync (a layer that shares BN statistics across multiple GPUs; I dropped it to get the network running quickly) must be commented out. Since I do not use this BN variant, modify seg_hrnet.py as follows:

# BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
BatchNorm2d = nn.BatchNorm2d

and comment out the following around line 464:

            #elif isinstance(m, InPlaceABNSync):
                #nn.init.constant_(m.weight, 1)
                #nn.init.constant_(m.bias, 0)

Downloading the Data

Link: https://pan.baidu.com/s/1dxsVOOZ1RC7c-obM23fHIg 
Extraction code: kmrl 

After downloading, put the gtFine and leftImg8bit folders into the data directory of the source tree. The resulting layout of data is:

$SEG_ROOT/data
├── cityscapes
│   ├── gtFine
│   │   ├── test
│   │   ├── train
│   │   └── val
│   └── leftImg8bit
│       ├── test
│       ├── train
│       └── val
├── list
│   ├── cityscapes
│   │   ├── test.lst
│   │   ├── trainval.lst
│   │   └── val.lst
│   ├── lip
│   │   ├── testvalList.txt
│   │   ├── trainList.txt
│   │   └── valList.txt

Note: remember to delete the following entry from val.lst, because this image is problematic:

leftImg8bit/val/frankfurt/frankfurt_000001_059119_leftImg8bit.png   gtFine/val/frankfurt/frankfurt_000001_059119_gtFine_labelIds.png
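A small script can drop that entry. This is a sketch; it only assumes the one-pair-per-line format of the .lst files shown above:

```python
# Drop the known-bad frankfurt_000001_059119 pair from a Cityscapes .lst file.
# Each line holds "image_path label_path" for one sample.
BAD_ID = "frankfurt_000001_059119"

def filter_lst(lines, bad_id=BAD_ID):
    """Keep only the list entries that do not reference the bad image."""
    return [ln for ln in lines if bad_id not in ln]

sample = [
    "leftImg8bit/val/frankfurt/a_leftImg8bit.png gtFine/val/frankfurt/a_gtFine_labelIds.png\n",
    "leftImg8bit/val/frankfurt/frankfurt_000001_059119_leftImg8bit.png gtFine/val/frankfurt/frankfurt_000001_059119_gtFine_labelIds.png\n",
]
kept = filter_lst(sample)  # only the first entry survives
```

Apply the same filter to the real data/list/cityscapes/val.lst and write the result back.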

Running the Training Script

python tools/train.py --cfg experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml

Remember to edit seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml; the lines marked with # are the ones I changed. I have 2 GPUs; for why WORKERS is 0, see https://blog.csdn.net/u013066730/article/details/97808471

CUDNN:
  BENCHMARK: true
  DETERMINISTIC: false
  ENABLED: true
GPUS: (0,1) #########################
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 0  ######################
PRINT_FREQ: 100

DATASET:
  DATASET: cityscapes
  ROOT: 'data/'
  TEST_SET: 'list/cityscapes/val.lst'
  TRAIN_SET: 'list/cityscapes/train.lst'
  NUM_CLASSES: 19
MODEL:
  NAME: seg_hrnet
  PRETRAINED: 'pretrained_models/hrnetv2_w48_imagenet_pretrained.pth'
  EXTRA:
    FINAL_CONV_KERNEL: 1
    STAGE2:
      NUM_MODULES: 1
      NUM_BRANCHES: 2
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      FUSE_METHOD: SUM
    STAGE3:
      NUM_MODULES: 4
      NUM_BRANCHES: 3
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      - 192
      FUSE_METHOD: SUM
    STAGE4:
      NUM_MODULES: 3
      NUM_BRANCHES: 4
      BLOCK: BASIC
      NUM_BLOCKS:
      - 4
      - 4
      - 4
      - 4
      NUM_CHANNELS:
      - 48
      - 96
      - 192
      - 384
      FUSE_METHOD: SUM
LOSS:
  USE_OHEM: false
  OHEMTHRES: 0.9
  OHEMKEEP: 131072
TRAIN:
  IMAGE_SIZE:
  - 1024
  - 512
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 2 ############################
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 484
  RESUME: true
  OPTIMIZER: sgd
  LR: 0.01
  WD: 0.0005
  MOMENTUM: 0.9
  NESTEROV: false
  FLIP: true
  MULTI_SCALE: true
  DOWNSAMPLERATE: 1
  IGNORE_LABEL: 255
  SCALE_FACTOR: 16
TEST:
  IMAGE_SIZE:
  - 2048
  - 1024
  BASE_SIZE: 2048
  BATCH_SIZE_PER_GPU: 3 ###############################
  FLIP_TEST: false
  MULTI_SCALE: false

Training Code Walkthrough (only the key parts; please follow along with the source)

Parsing the cfg Arguments

Open tools/train.py, which parses the parameters from the YAML file using the yacs library. If you are not familiar with yacs, see https://blog.csdn.net/u013066730/article/details/97640131
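The core behavior of yacs is a recursive override of a tree of defaults by the YAML values. This pure-Python sketch mimics only that merge rule (the real project uses yacs' CfgNode, not plain dicts):

```python
# A pure-Python sketch of what yacs' merge (used by update_config) does:
# a tree of default values is recursively overridden by the YAML values.
def merge(defaults, overrides):
    """Recursively override `defaults` with `overrides`, returning a new dict."""
    out = dict(defaults)
    for key, val in overrides.items():
        if isinstance(val, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], val)   # descend into nested sections
        else:
            out[key] = val                    # leaf value: YAML wins
    return out

defaults = {"GPUS": (0,), "TRAIN": {"LR": 0.01, "END_EPOCH": 484}}
yaml_cfg = {"GPUS": (0, 1), "TRAIN": {"LR": 0.001}}
config = merge(defaults, yaml_cfg)
# TRAIN.END_EPOCH keeps its default; GPUS and TRAIN.LR come from the YAML
```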

from config import config
from config import update_config

def parse_args():
    parser = argparse.ArgumentParser(description='Train segmentation network')
    
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()
    update_config(config, args)

    return args

def main():
    args = parse_args()

The imports show that config comes straight from the lib/config package, i.e. from its __init__.py, which in turn imports some parameters and functions from .default; as far as I can tell from testing, .models is effectively unused. For how config and update_config update the parameters, see https://blog.csdn.net/u013066730/article/details/97640131

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .default import _C as config
from .default import update_config
from .models import MODEL_EXTRAS

Logger Setup

logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(config)

    writer_dict = {
        'writer': SummaryWriter(tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

I won't go into the logger details here; please read the source yourself.

CUDA Setup

# cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    gpus = list(config.GPUS)

Building the Model (detailed later, so only touched on briefly here)

# build model
    model = eval('models.'+config.MODEL.NAME +
                 '.get_seg_model')(config)

Logging the Model Summary (TensorBoard)

    dump_input = torch.rand(
        (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
    )
    logger.info(get_model_summary(model.cuda(), dump_input.cuda()))

Copying the Model Files (a backup, presumably?)

# copy model file
    this_dir = os.path.dirname(__file__)
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)
    shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

The copy destination is HRNet-Semantic-Segmentation-master\output\cityscapes\seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484\models

Loading the Data

import datasets


train_dataset = eval('datasets.'+config.DATASET.DATASET)(
                        root=config.DATASET.ROOT,
                        list_path=config.DATASET.TRAIN_SET,
                        num_samples=None,
                        num_classes=config.DATASET.NUM_CLASSES,
                        multi_scale=config.TRAIN.MULTI_SCALE,
                        flip=config.TRAIN.FLIP,
                        ignore_label=config.TRAIN.IGNORE_LABEL,
                        base_size=config.TRAIN.BASE_SIZE,
                        crop_size=crop_size,
                        downsample_rate=config.TRAIN.DOWNSAMPLERATE,
                        scale_factor=config.TRAIN.SCALE_FACTOR)

trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True)

The import pulls in the datasets package, i.e. runs datasets/__init__.py, so let's look at its contents:

# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from .cityscapes import Cityscapes as cityscapes
from .lip import LIP as lip
from .pascal_ctx import PASCALContext as pascal_ctx

Since the Cityscapes dataset is used, we only need to look at the cityscapes module. Note that eval('datasets.' + config.DATASET.DATASET) becomes eval('datasets.cityscapes') after substituting the config value: the string 'datasets.cityscapes' is evaluated as code, yielding the callable datasets.cityscapes, which is then invoked.
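The eval trick can be demonstrated in isolation. The stand-in datasets namespace below is hypothetical, just to show how the config string resolves to a class:

```python
# Demonstrates how eval('datasets.' + name) turns a config string into a class.
# A stand-in `datasets` namespace replaces the real package here.
import types

class Cityscapes:
    def __init__(self, root):
        self.root = root

datasets = types.SimpleNamespace()
datasets.cityscapes = Cityscapes  # mirrors `from .cityscapes import Cityscapes as cityscapes`

name = "cityscapes"                        # config.DATASET.DATASET
dataset_cls = eval("datasets." + name)     # the string resolves to the class object
train_dataset = dataset_cls(root="data/")  # same as calling datasets.cityscapes(root="data/")
```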

With that, the Cityscapes class in datasets\cityscapes.py is straightforward. I won't expand on the implementation; it is essentially a standard data-loading class with a few extras. The main thing to note is its final output, shown below.

return image.copy(), label.copy(), np.array(size), name

At this point image has shape (3, 512, 1024) and label has shape (512, 1024).

Loss Function

# criterion
    if config.LOSS.USE_OHEM:
        criterion = OhemCrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                     thres=config.LOSS.OHEMTHRES,
                                     min_kept=config.LOSS.OHEMKEEP,
                                     weight=train_dataset.class_weights)
    else:
        criterion = CrossEntropy(ignore_label=config.TRAIN.IGNORE_LABEL,
                                 weight=train_dataset.class_weights)

Since the config does not enable OHEM, the plain weighted cross-entropy loss is used. The CrossEntropy class is in lib\core\criterion.py:

class CrossEntropy(nn.Module):
    def __init__(self, ignore_label=-1, weight=None):
        super(CrossEntropy, self).__init__()
        self.ignore_label = ignore_label
        self.criterion = nn.CrossEntropyLoss(weight=weight, 
                                             ignore_index=ignore_label)

    def forward(self, score, target):
        ph, pw = score.size(2), score.size(3)  #target shape is [2, 512, 1024], score shape is [2, 19, 128, 256]
        h, w = target.size(1), target.size(2)
        if ph != h or pw != w:
            score = F.upsample(
                    input=score, size=(h, w), mode='bilinear')

        loss = self.criterion(score, target)

        return loss

The shapes of target and score are given in the comments: score must be upsampled to the target's size, which in this code enlarges both width and height by a factor of 4.
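A quick shape check of that factor, using the sizes from the comment:

```python
# The network predicts at 1/4 resolution: score is (N, 19, 128, 256) while the
# target is (N, 512, 1024), so the bilinear upsample scales H and W by 4.
score_hw = (128, 256)    # ph, pw from the comment above
target_hw = (512, 1024)  # h, w

scale_h = target_hw[0] / score_hw[0]
scale_w = target_hw[1] / score_hw[1]
# both scale factors are 4.0
```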

Wrapping the Model with the Loss

model = FullModel(model, criterion)
model = nn.DataParallel(model, device_ids=gpus).cuda()

Optimizer

# optimizer
    if config.TRAIN.OPTIMIZER == 'sgd':
        optimizer = torch.optim.SGD([{'params':
                                  filter(lambda p: p.requires_grad,
                                         model.parameters()),
                                  'lr': config.TRAIN.LR}],
                                lr=config.TRAIN.LR,
                                momentum=config.TRAIN.MOMENTUM,
                                weight_decay=config.TRAIN.WD,
                                nesterov=config.TRAIN.NESTEROV,
                                )
    else:
        raise ValueError('Only Support SGD optimizer')

Iteration Counting and Checkpoint Resuming

    epoch_iters = np.int(train_dataset.__len__() / 
                        config.TRAIN.BATCH_SIZE_PER_GPU / len(gpus))
    best_mIoU = 0
    last_epoch = 0
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            best_mIoU = checkpoint['best_mIoU']
            last_epoch = checkpoint['epoch']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})"
                        .format(checkpoint['epoch']))

    start = timeit.default_timer()
    end_epoch = config.TRAIN.END_EPOCH + config.TRAIN.EXTRA_EPOCH
    num_iters = config.TRAIN.END_EPOCH * epoch_iters
    extra_iters = config.TRAIN.EXTRA_EPOCH * epoch_iters
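To make the counters concrete, here is the arithmetic with the values from this setup (the Cityscapes train split size of 2975 images is my assumption; substitute your own len(train_dataset)):

```python
# epoch_iters = np.int(len(train_dataset) / BATCH_SIZE_PER_GPU / len(gpus));
# np.int truncates, so integer division gives the same result.
num_samples = 2975            # assumed size of the Cityscapes train split
batch_per_gpu, num_gpus = 2, 2
end_epoch = 484

epoch_iters = num_samples // (batch_per_gpu * num_gpus)  # iterations per epoch
num_iters = end_epoch * epoch_iters                      # total iterations (drives the LR schedule)
```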

Training and Validation

    for epoch in range(last_epoch, end_epoch):
        if epoch >= config.TRAIN.END_EPOCH:
            train(config, epoch-config.TRAIN.END_EPOCH, 
                  config.TRAIN.EXTRA_EPOCH, epoch_iters, 
                  config.TRAIN.EXTRA_LR, extra_iters, 
                  extra_trainloader, optimizer, model, writer_dict)
        else:
            train(config, epoch, config.TRAIN.END_EPOCH, 
                  epoch_iters, config.TRAIN.LR, num_iters,
                  trainloader, optimizer, model, writer_dict)

        logger.info('=> saving checkpoint to {}'.format(
            final_output_dir + 'checkpoint.pth.tar'))
        torch.save({
            'epoch': epoch+1,
            'best_mIoU': best_mIoU,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, os.path.join(final_output_dir,'checkpoint.pth.tar'))
        valid_loss, mean_IoU, IoU_array = validate(
                        config, testloader, model, writer_dict)
        if mean_IoU > best_mIoU:
            best_mIoU = mean_IoU
            torch.save(model.module.state_dict(),
                       os.path.join(final_output_dir, 'best.pth'))
        msg = 'Loss: {:.3f}, MeanIU: {: 4.4f}, Best_mIoU: {: 4.4f}'.format(
                    valid_loss, mean_IoU, best_mIoU)
        logging.info(msg)
        logging.info(IoU_array)

Not much to explain here; the train and validate functions live in core/function.py. Let's look at train.

def train(config, epoch, num_epoch, epoch_iters, base_lr, 
        num_iters, trainloader, optimizer, model, writer_dict):
    # Training
    model.train()
    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch*epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    for i_iter, batch in enumerate(trainloader, 0):
        images, labels, _, _ = batch  #image shape is [4,3,512,1024] and label shape is  [4,512,1024]
        # print(images.size())
        # print(labels.size())
        labels = labels.long().cuda()

        losses, _ = model(images, labels)
        loss = losses.mean()

        model.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(loss.item())

        lr = adjust_learning_rate(optimizer,
                                  base_lr,
                                  num_iters,
                                  i_iter+cur_iters)

        if i_iter % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
                  'lr: {:.6f}, Loss: {:.6f}' .format(
                      epoch, num_epoch, i_iter, epoch_iters, 
                      batch_time.average(), lr, ave_loss.average())
            logging.info(msg)

    writer.add_scalar('train_loss', ave_loss.average(), global_steps)
    writer_dict['train_global_steps'] = global_steps + 1

Inside train, the batch shapes are image [4, 3, 512, 1024] and label [4, 512, 1024]. BATCH_SIZE_PER_GPU is 2, but with 2 GPUs the DataLoader yields batches of 4.
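The adjust_learning_rate call in the loop decays the LR each iteration. In the released code it is a poly schedule; power 0.9 is what I recall from lib/core/function.py, so treat that value as an assumption and verify against your copy:

```python
# Poly learning-rate schedule: decays base_lr toward 0 over num_iters steps.
# The power 0.9 is an assumption based on the released HRNet code.
def poly_lr(base_lr, num_iters, cur_iter, power=0.9):
    return base_lr * ((1 - cur_iter / num_iters) ** power)

lr_start = poly_lr(0.01, 1000, 0)    # full base LR at iteration 0
lr_half = poly_lr(0.01, 1000, 500)   # roughly half the base LR at the midpoint
```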

Detailed Walkthrough of the Network Architecture

First, open models\seg_hrnet.py:

def get_seg_model(cfg, **kwargs):
    model = HighResolutionNet(cfg, **kwargs)
    model.init_weights(cfg.MODEL.PRETRAINED)

    return model

The code does two things: HighResolutionNet builds the network structure, and init_weights initializes it from the pretrained model. Let's focus on HighResolutionNet.

For what follows, keep the source open; I only excerpt the parts under discussion.

These are the methods of HighResolutionNet:

class HighResolutionNet(nn.Module):

    def __init__(self, config, **kwargs):
        extra = config.MODEL.EXTRA
        super(HighResolutionNet, self).__init__()

    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer)
     

    def _make_layer(self, block, inplanes, planes, blocks, stride=1)

    def _make_stage(self, layer_config, num_inchannels,
                    multi_scale_output=True)

    def forward(self, x): # x has shape (2, 3, 512, 1024)

    def init_weights(self, pretrained='',)

Start with forward; the input x has shape (2, 3, 512, 1024). conv1 and conv2 are both 3x3 convolutions with stride 2, so each halves the spatial size; after both, height and width are 1/4 of the original and the shape is (2, 64, 128, 256). Then come 4 residual units (self.layer1), after which the shape is (2, 256, 128, 256).

    def forward(self, x): # x has shape (2, 3, 512, 1024)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x) # x has shape (2, 256, 128, 256)

The concrete steps are illustrated in the accompanying figure.
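The stem shapes can be verified with the standard convolution output formula; a minimal check:

```python
# Shape bookkeeping for the stem: two 3x3 convs with stride 2 and padding 1
# shrink a 512x1024 input to 128x256 (each conv halves H and W).
def conv_out(size, kernel=3, stride=2, pad=1):
    """Output size of a conv layer: floor((size + 2*pad - kernel)/stride) + 1."""
    return (size + 2 * pad - kernel) // stride + 1

h, w = 512, 1024
for _ in range(2):      # conv1, then conv2
    h, w = conv_out(h), conv_out(w)
# (h, w) is now (128, 256), matching the (2, 64, 128, 256) shape above
```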


Continuing with the code: this part adds branches and runs convolutions within each branch.

        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)

self.transition1 is built by self.transition1 = self._make_transition_layer([256], num_channels); follow that into _make_transition_layer:

i   num_branches_pre   operation
0   1                  conv(256,48) k=3, s=1
1   1                  conv(256,96) k=3, s=2
    def _make_transition_layer(
            self, num_channels_pre_layer, num_channels_cur_layer):
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)


        # num_branches_cur is 2, num_branches_pre is 1
        # i=0: enters the if num_channels_cur_layer[i] != num_channels_pre_layer[i] check;
        # since the channel counts differ, that branch is executed.
        # i=1: enters the for j in range(i+1-num_branches_pre) loop, producing a smaller feature map
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(nn.Sequential(
                        nn.Conv2d(num_channels_pre_layer[i],
                                  num_channels_cur_layer[i],
                                  3,
                                  1,
                                  1,
                                  bias=False),
                        BatchNorm2d(
                            num_channels_cur_layer[i], momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=False)))
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i+1-num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] \
                        if j == i-num_branches_pre else inchannels
                    conv3x3s.append(nn.Sequential(
                        nn.Conv2d(
                            inchannels, outchannels, 3, 2, 1, bias=False),
                        BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                        nn.ReLU(inplace=False)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

So x passes through two convolutions and becomes x_list, which holds 2 branches with shapes (2, 48, 128, 256) and (2, 96, 64, 128).


Continuing down, we reach:

y_list = self.stage2(x_list) # fuse, yielding 2 branches

Then comes _make_stage, which mainly wires up the parameters and how many times the big module is repeated. Next, let's look at the HighResolutionModule class used inside _make_stage.

class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
                 num_channels, fuse_method, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(
            num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(
            num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=False)

    def _check_branches(self, num_branches, blocks, num_blocks,
                        num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                         stride=1):
        downsample = None
        if stride != 1 or \
           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.num_inchannels[branch_index],
                          num_channels[branch_index] * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(num_channels[branch_index] * block.expansion,
                            momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index],
                            num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = \
            num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index],
                                num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1): # j is the input branch, i is the output branch
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(nn.Sequential(
                        nn.Conv2d(num_inchannels[j],
                                  num_inchannels[i],
                                  1,
                                  1,
                                  0,
                                  bias=False),
                        BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i-j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3, 
                                            momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(nn.Sequential(
                                nn.Conv2d(num_inchannels[j],
                                          num_outchannels_conv3x3,
                                          3, 2, 1, bias=False),
                                BatchNorm2d(num_outchannels_conv3x3,
                                            momentum=BN_MOMENTUM),
                                nn.ReLU(inplace=False)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear')
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse

Again start from forward: first the 2 branches each run their own convolutions, built by _make_branches; each branch applies 4 residual units, as shown in the figure.

After that comes the fusion, mainly the fuse_layers; here j is the input branch and i the output branch.

i   j   operation
0   0   None
0   1   conv(96,48) k=1, s=1
1   0   conv(48,96) k=3, s=2
1   1   None

Note that fuse_layers itself does not change the spatial size for the j > i case; the upsampling happens in forward (via F.interpolate).
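Tables like this can be generated mechanically from the rules in _make_fuse_layers. A sketch for the 2-branch stage2 case (channels 48 and 96); the op strings are my own shorthand:

```python
# Reproduces the fuse table above: for output branch i and input branch j,
# _make_fuse_layers builds a 1x1 conv (upsampled later in forward) when j > i,
# nothing when j == i, and (i - j) stride-2 3x3 convs when j < i.
def fuse_op(i, j, channels):
    if j > i:
        return f"conv1x1({channels[j]},{channels[i]}) + upsample"
    if j == i:
        return "None"
    return f"{i - j} x conv3x3 s=2 (down to {channels[i]} ch)"

channels = [48, 96]  # stage2 branch widths
table = {(i, j): fuse_op(i, j, channels)
         for i in range(len(channels)) for j in range(len(channels))}
```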

Back in HighResolutionModule.forward, the fusion is performed, i.e. a sum:

for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=[height_output, width_output],
                        mode='bilinear')
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

Matching the table above against this code: j is the input branch, i is the output branch, and several j's are fused (summed) into one i, as the figure shows.

Returning to HighResolutionNet's forward, y_list has shapes (2, 48, 128, 256) and (2, 96, 64, 128).


Next comes transition2:

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                if i < self.stage2_cfg['NUM_BRANCHES']:
                    x_list.append(self.transition2[i](y_list[i]))
                else:
                    x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])

This turns the 2 branches into 3 branches, per the following table:

i (0,1,2)   num_branches_pre   operation
0           2                  None
1           2                  None
2           2                  conv(96,192) k=3, s=2

Now return to HighResolutionNet's forward and refer to the figure: the yellow and orange blocks inside the pink frame show that self.transition2 simply copies those two feature maps with no extra operation, while the small pink block is produced from the orange block by conv(96,192), k=3, s=2. This differs slightly from the paper.

 


Next comes the per-branch convolution and fusion inside the stage, which I won't detail since the flow is the same as above.

y_list = self.stage3(x_list)

A few points deserve attention, explained against the figure below.

The blue frame covers everything stage3 does, and the operations inside it are repeated 4 times (NUM_MODULES: 4).

One pass consists of 2 steps, _make_branches and _make_fuse_layers. I won't enumerate every case; just substitute the values yourself to verify.


Next are transition3 and stage4, covered together since they are much like what came before; again refer to the figure.

The black frame is transition3, performed after the 4 in-frame passes finish, using the result of the last pass: the black block is obtained from the last purple block by a conv(192,384) k=3, s=2 convolution.


Finally the four branches are combined: the smaller feature maps are upsampled to the largest spatial size and then concatenated.

        # Upsampling
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.upsample(x[1], size=(x0_h, x0_w), mode='bilinear')
        x2 = F.upsample(x[2], size=(x0_h, x0_w), mode='bilinear')
        x3 = F.upsample(x[3], size=(x0_h, x0_w), mode='bilinear')

        x = torch.cat([x[0], x1, x2, x3], 1)  # shape (N, 15C, h, w); concretely (2, 720, 128, 256)

        x = self.last_layer(x) # concretely (2, 19, 128, 256)

That wraps up the network architecture.
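As a sanity check on the final shapes: the W48 branch widths are C, 2C, 4C, 8C with C = 48, so the concatenation yields 15C channels:

```python
# Channel arithmetic of the final concatenation and classification head.
C = 48
branch_channels = [C, 2 * C, 4 * C, 8 * C]   # [48, 96, 192, 384]
concat_channels = sum(branch_channels)       # 720 = 15 * C, matching (2, 720, 128, 256)
num_classes = 19                             # output channels of last_layer
```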
