HRNet

本文详细分析了HRNet的源码,包括数据准备、模型设计、训练阶段以及数据流。从mpii.py和JointsDataset.py的数据预处理到模型的构建,如BasicBlock、Bottleneck和HighResolutionModule。通过对各模块的深入理解,揭示了HRNet在人体关键点检测中的工作原理和实现细节。
摘要由CSDN通过智能技术生成

在这里插入图片描述

论文链接:https://arxiv.org/abs/1902.09212

代码链接:https://github.com/leoxiaobin/deep-high-resolution-net.pytorch

论文源码分析:

1 源码准备

在指定文件夹下,输入命令:

git clone https://github.com/leoxiaobin/deep-high-resolution-net.pytorch.git

下载完成后,得到HRNet源码

2 源码结构

下表列出HRNet中比较重要的文件:

文件名称 功能
tools/trian.py 训练脚本
tools/test.py 测试脚本
lib/dataset/mpii.py 对MPII数据集进行预处理
lib/dataset/JointsDataSet 数据读取脚本
lib/models/pose_hrnet.py 网络结构构建脚本
lib/utils HRNet的一些方法
experiments/mpii/hrnet HRNet网络的初始化参数脚本

接下来对一些重要文件,将一一讲解,并且说清数据流的走向和函数调用关系。

3 源码分析(准备阶段)

3.1 数据准备

3.1.1 mpii.py

通过阅读源码可以知道,通过mpii.py文件中的MPIIDataset的初始化函数,将获得一个rec的数据,其中包含:coco中所有人体,对应关键点的信息、图片路径、标准化以及缩放比例等信息。

(1) _init_函数

class MPIIDataset(JointsDataset):
    def __init__(self, cfg, root, image_set, is_train, transform=None):
        super().__init__(cfg, root, image_set, is_train, transform)

        self.num_joints = 16
        self.flip_pairs = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
        self.parent_ids = [1, 2, 6, 6, 3, 4, 6, 6, 7, 8, 11, 12, 7, 7, 13, 14]

        self.upper_body_ids = (7, 8, 9, 10, 11, 12, 13, 14, 15)
        self.lower_body_ids = (0, 1, 2, 3, 4, 5, 6)

        self.db = self._get_db()

        if is_train and cfg.DATASET.SELECT_DATA:
            self.db = self.select_data(self.db)

        logger.info('=> load {} samples'.format(len(self.db)))

MPIIDataSet类的初始化方法_init_需要如下参数:

  • num_joints : MPII数据集中人体关键点标记个数
  • flip_pairs : 人体水平对称关键映射
  • parents_ids : 父母ids
  • upper_body_ids : 定义上半身关键点
  • lower_body_ids : 定义下半身关键点
  • db : 读取目标检测模型

(2) _get_db函数

def _get_db(self):
        # create train/val split
        file_name = os.path.join(
            self.root, 'annot', self.image_set+'.json'
        )
        with open(file_name) as anno_file:
            anno = json.load(anno_file)

        gt_db = []
        for a in anno:
            image_name = a['image']

            c = np.array(a['center'], dtype=np.float)
            s = np.array([a['scale'], a['scale']], dtype=np.float)
            if c[0] != -1:
                c[1] = c[1] + 15 * s[1]
                s = s * 1.25
            c = c - 1

            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
            joints_3d_vis = np.zeros((self.num_joints,  3), dtype=np.float)
            if self.image_set != 'test':
                joints = np.array(a['joints'])
                joints[:, 0:2] = joints[:, 0:2] - 1
                joints_vis = np.array(a['joints_vis'])
                assert len(joints) == self.num_joints, \
                    'joint num diff: {} vs {}'.format(len(joints),
                                                      self.num_joints)

                joints_3d[:, 0:2] = joints[:, 0:2]
                joints_3d_vis[:, 0] = joints_vis[:]
                joints_3d_vis[:, 1] = joints_vis[:]

            image_dir = 'images.zip@' if self.data_format == 'zip' else 'images'
            gt_db.append(
                {
   
                    'image': os.path.join(self.root, image_dir, image_name),
                    'center': c,
                    'scale': s,
                    'joints_3d': joints_3d, 
                    'joints_3d_vis': joints_3d_vis,
                    'filename': '',
                    'imgnum': 0,
                }
            )

        return gt_db

首先找到MPII数据集的分割依据文件annotaion,之后循环遍历该数据集,读取每张图片的名称、中心点位置、大小、人体关键节点位置(用三维坐标表示)、可见的人体关键节点位置并保存,形成一个字典不断加入到gt_db,循环结束返回。数据预处理到这并没有结束,因为还需要进一步处理,原因在于当计算loss的时候,我们需要的是热图(heatmap)。

3.1.2 JointsDataset.py

接下来,我们需要根据get_db中的信息,读取图片像素(用于训练),同时把标签信息转化为heatmap。

(1) init.py

class JointsDataset(Dataset):
    def __init__(self, cfg, root, image_set, is_train, transform=None):
        self.num_joints = 0# 人体关节的数目
        self.pixel_std = 200# 像素标准化参数
        self.flip_pairs = []# 水平翻转
        self.parent_ids = []# 父母ID==

        self.is_train = is_train# 是否进行训练
        self.root = root# 训练数据根目录
        self.image_set = image_set# 图片数据集名称,如‘train2017’

        self.output_path = cfg.OUTPUT_DIR# 输出目录
        self.data_format = cfg.DATASET.DATA_FORMAT# 数据格式如‘jpg’

        self.scale_factor = cfg.DATASET.SCALE_FACTOR# 缩放因子
        self.rotation_factor = cfg.DATASET.ROT_FACTOR # 旋转角度
        self.flip = cfg.DATASET.FLIP# 是否进行水平翻转
        self.num_joints_half_body = cfg.DATASET.NUM_JOINTS_HALF_BODY# 人体一半关键点的数目,默认为8
        self.prob_half_body = cfg.DATASET.PROB_HALF_BODY# 人体一半的概率
        self.color_rgb = cfg.DATASET.COLOR_RGB# 图片格式,默认为rgb

        self.target_type = cfg.MODEL.TARGET_TYPE# 目标数据的类型,默认为高斯分布
        self.image_size = np.array(cfg.MODEL.IMAGE_SIZE)# 网络训练图片大小,如[192,256]
        self.heatmap_size = np.array(cfg.MODEL.HEATMAP_SIZE)# 标签热图的大小
        self.sigma = cfg.MODEL.SIGMA# sigma参数,默认为2
        self.use_different_joints_weight = cfg.LOSS.USE_DIFFERENT_JOINTS_WEIGHT# 是否对每个关节使用不同的权重,默认为false
        self.joints_weight = 1# 关节权重

        self.transform = transform# 数据增强,转换等
        self.db = []# 用于保存训练数据的信息,由子类提供

_init_函数的功能在于初始化JointsDataset模型,设置一些参数和参数默认值,每个参数值的作用已经注释。通过这些初始化操作,可以获得一些基本信息,如人体关节数目、图片格式、标签热图的大小、关节权重等。

(2) _getitem_函数

	def __getitem_(self,idx):	
        db_rec = copy.deepcopy(self.db[idx])
        image_file = db_rec['image']
        filename = db_rec['filename'] if 'fename' in db_rec else ''
        imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''
        if self.data_format == 'zip':
            from utils import zipreader
            data_numpy = zipreader.imread(
                image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
            )
        else:
            data_numpy = cv2.imread(
                image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
            )

        if self.color_rgb:
            data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
        if data_numpy is None:
            logger.error('=> fail to read {}'.format(image_file))
            raise ValueError('Fail to read {}'.format(image_file))
	
        joints = db_rec['joints_3d']# 人体3d关键点的所有坐标
        joints_vis = db_rec['joints_3d_vis']# 人体3d关键点的所有可视坐标

        # 获取训练样本转化之后的center以及scale,
        c = db_rec['center']
        s = db_rec['scale']
        
        # 如果训练样本中没有设置score,则加载该属性,并且设置为1
        score = db_rec['score'] if 'score' in
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值