HR-Net

1902_HR-Net:

图:

在这里插入图片描述

网络描述:

在这篇论文中,我们主要研究人的姿态问题(human pose estimation problem),着重于输出可靠的高分辨率表征(reliable highresolution representations)。现有的大多数方法都是从高分辨率到低分辨率网络(high-to-low resolution network)产生的低分辨率表征中恢复高分辨率表征。相反,我们提出的网络能在整个过程中都保持高分辨率的表征。

我们从高分辨率子网络(high-resolution subnetwork)作为第一阶段开始,逐步增加高分辨率到低分辨率的子网,形成更多的阶段,并将多分辨率子网并行连接。我们进行了多次多尺度融合multi-scale fusions,使得每一个高分辨率到低分辨率的表征都从其他并行表示中反复接收信息,从而得到丰富的高分辨率表征。因此,预测的关键点热图可能更准确,在空间上也更精确。通过 COCO keypoint detection 数据集和 MPII Human Pose 数据集这两个基准数据集的pose estimation results,我们证明了网络的有效性。此外,我们还展示了网络在 Pose Track 数据集上的姿态跟踪的优越性。

特点,优点:

(1) 我们的方法是并行连接高分辨率到低分辨率的子网,而不是像大多数现有解决方案那样串行接。因此,我们的方法能够保持高分辨率,而不是通过一个低到高的过程恢复分辨率,因此预测的热图可能在空间上更精确。 并行高分辨率子网,每个子网分辨率不变。

(2) 大多数现有的融合方案都将低层和高层的表示集合起来。相反,我们使用重复的多尺度融合,利用相同深度和相似级别的低分辨率表示来提高高分辨率表示,反之亦然,从而使得高分辨率表示对于姿态的估计也很充分。因此,我们预测的热图可能更准确 。多尺度融合,多个子网之间相互融合,信息融合非常充分。

代码:
pytorch实现:
def parse_args():
    parser = argparse.ArgumentParser(description='Train keypoints network')
    # general,指定yaml文件的路径
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)

    # 暂时没有具体实现
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    # philly
    # 模型的目录
    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')

    # 输出log的目录
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')

    # 训练数据的目录
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')

    # 预训练模型的目录
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')


    args = parser.parse_args()

    return args


def main():
    # 对输入参数进行解析
    args = parse_args()
    # 根据输入参数对cfg进行更新
    update_config(cfg, args)

    # 创建logger,用于记录训练过程的打印信息
    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)


    # cudnn related setting
    # 使用GPU的一些相关设置
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # 根据配置文件构建网络
    print('models.'+cfg.MODEL.NAME+'.get_pose_net')
    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )

    # copy model file,拷贝lib/models/pose_hrnet.py文件到输出目录之中
    this_dir = os.path.dirname(__file__)
    print(os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'))
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    # 用于训练信息的图形化显示
    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # 用于模型的图形化显示
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0])
    )
    #writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))



    # 让模型支持多GPU训练
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()


    # define loss function (criterion) and optimizer,用于计算loss
    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()


    # Data loading code,对输入图象数据进行正则化处理
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # 创建训练以及测试数据的迭代器
    train_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY
    )


    # 模型加载以及优化策略的相关配置
    best_perf = 0.0 #
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth'
    )

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
        last_epoch=last_epoch
    )

    #循环迭代进行训练
    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)


        # evaluate on validation set
        perf_indicator = validate(
            cfg, valid_loader, valid_dataset, model, criterion,
            final_output_dir, tb_log_dir, writer_dict
        )

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': cfg.MODEL.NAME,
            'state_dict': model.state_dict(),
            'best_state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state.pth'
    )
    logger.info('=> saving final model state to {}'.format(
        final_model_state_file)
    )
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()


if __name__ == '__main__':
    main()
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值