论文链接:https://arxiv.org/abs/1902.09212
代码链接:https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
论文源码分析:
1 源码准备
在指定文件夹下,输入命令:
git clone https://github.com/leoxiaobin/deep-high-resolution-net.pytorch.git
下载完成后,得到HRNet源码
2 源码结构
下表列出HRNet中比较重要的文件:
文件名称 | 功能 |
---|---|
tools/trian.py | 训练脚本 |
tools/test.py | 测试脚本 |
lib/dataset/mpii.py | 对MPII数据集进行预处理 |
lib/dataset/JointsDataSet | 数据读取脚本 |
lib/models/pose_hrnet.py | 网络结构构建脚本 |
lib/utils | HRNet的一些方法 |
experiments/mpii/hrnet | HRNet网络的初始化参数脚本 |
接下来对一些重要文件,将一一讲解,并且说清数据流的走向和函数调用关系。
3 源码分析(准备阶段)
3.1 数据准备
3.1.1 mpii.py
通过阅读源码可以知道,通过mpii.py文件中的MPIIDataset的初始化函数,将获得一个rec的数据,其中包含:coco中所有人体,对应关键点的信息、图片路径、标准化以及缩放比例等信息。
(1) _init_函数
class MPIIDataset(JointsDataset):
def __init__(self, cfg, root, image_set, is_train, transform=None):
super().__init__(cfg, root, image_set, is_train, transform)
self.num_joints = 16
self.flip_pairs = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
self.parent_ids = [1, 2, 6, 6, 3, 4, 6, 6, 7, 8, 11, 12, 7, 7, 13, 14]
self.upper_body_ids = (7, 8, 9, 10, 11, 12, 13, 14, 15)
self.lower_body_ids = (0, 1, 2, 3, 4, 5, 6)
self.db = self._get_db()
if is_train and cfg.DATASET.SELECT_DATA:
self.db = self.select_data(self.db)
logger.info('=> load {} samples'.format(len(self.db)))
MPIIDataSet类的初始化方法_init_需要如下参数:
- num_joints : MPII数据集中人体关键点标记个数
- flip_pairs : 人体水平对称关键映射
- parents_ids : 父母ids
- upper_body_ids : 定义上半身关键点
- lower_body_ids : 定义下半身关键点
- db : 读取目标检测模型
(2) _get_db函数
def _get_db(self):
# create train/val split
file_name = os.path.join(
self.root, 'annot', self.image_set+'.json'
)
with open(file_name) as anno_file:
anno = json.load(anno_file)
gt_db = []
for a in anno:
image_name = a['image']
c = np.array(a['center'], dtype=np.float)
s = np.array([a['scale'], a['scale']], dtype=np.float)
if c[0] != -1:
c[1] = c[1] + 15 * s[1]
s = s * 1.25
c = c - 1
joints_3d = np.zeros((self.num_joints, 3), dtype=np.float)
joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float)
if self.image_set != 'test':
joints = np.array(a['joints'])
joints[:, 0:2] = joints[:, 0:2] - 1
joints_vis = np.array(a['joints_vis'])
assert len(joints) == self.num_joints, \
'joint num diff: {} vs {}'.format(len(joints),
self.num_joints)
joints_3d[:, 0:2] = joints[:, 0:2]
joints_3d_vis[:, 0] = joints_vis[:]
joints_3d_vis[:, 1] = joints_vis[:]
image_dir = 'images.zip@' if self.data_format == 'zip' else 'images'
gt_db.append(
{
'image': os.path.join(self.root, image_dir, image_name),
'center': c,
'scale': s,
'joints_3d': joints_3d,
'joints_3d_vis': joints_3d_vis,
'filename': '',
'imgnum': 0,
}
)
return gt_db
首先找到MPII数据集的分割依据文件annotaion,之后循环遍历该数据集,读取每张图片的名称、中心点位置、大小、人体关键节点位置(用三维坐标表示)、可见的人体关键节点位置并保存,形成一个字典不断加入到gt_db,循环结束返回。数据预处理到这并没有结束,因为还需要进一步处理,原因在于当计算loss的时候,我们需要的是热图(heatmap)。
3.1.2 JointsDataset.py
接下来,我们需要根据get_db中的信息,读取图片像素(用于训练),同时把标签信息转化为heatmap。
(1) init.py
class JointsDataset(Dataset):
def __init__(self, cfg, root, image_set, is_train, transform=None):
self.num_joints = 0# 人体关节的数目
self.pixel_std = 200# 像素标准化参数
self.flip_pairs = []# 水平翻转
self.parent_ids = []# 父母ID==
self.is_train = is_train# 是否进行训练
self.root = root# 训练数据根目录
self.image_set = image_set# 图片数据集名称,如‘train2017’
self.output_path = cfg.OUTPUT_DIR# 输出目录
self.data_format = cfg.DATASET.DATA_FORMAT# 数据格式如‘jpg’
self.scale_factor = cfg.DATASET.SCALE_FACTOR# 缩放因子
self.rotation_factor = cfg.DATASET.ROT_FACTOR # 旋转角度
self.flip = cfg.DATASET.FLIP# 是否进行水平翻转
self.num_joints_half_body = cfg.DATASET.NUM_JOINTS_HALF_BODY# 人体一半关键点的数目,默认为8
self.prob_half_body = cfg.DATASET.PROB_HALF_BODY# 人体一半的概率
self.color_rgb = cfg.DATASET.COLOR_RGB# 图片格式,默认为rgb
self.target_type = cfg.MODEL.TARGET_TYPE# 目标数据的类型,默认为高斯分布
self.image_size = np.array(cfg.MODEL.IMAGE_SIZE)# 网络训练图片大小,如[192,256]
self.heatmap_size = np.array(cfg.MODEL.HEATMAP_SIZE)# 标签热图的大小
self.sigma = cfg.MODEL.SIGMA# sigma参数,默认为2
self.use_different_joints_weight = cfg.LOSS.USE_DIFFERENT_JOINTS_WEIGHT# 是否对每个关节使用不同的权重,默认为false
self.joints_weight = 1# 关节权重
self.transform = transform# 数据增强,转换等
self.db = []# 用于保存训练数据的信息,由子类提供
_init_函数的功能在于初始化JointsDataset模型,设置一些参数和参数默认值,每个参数值的作用已经注释。通过这些初始化操作,可以获得一些基本信息,如人体关节数目、图片格式、标签热图的大小、关节权重等。
(2) _getitem_函数
def __getitem_(self,idx):
db_rec = copy.deepcopy(self.db[idx])
image_file = db_rec['image']
filename = db_rec['filename'] if 'fename' in db_rec else ''
imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''
if self.data_format == 'zip':
from utils import zipreader
data_numpy = zipreader.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
)
else:
data_numpy = cv2.imread(
image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
)
if self.color_rgb:
data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
if data_numpy is None:
logger.error('=> fail to read {}'.format(image_file))
raise ValueError('Fail to read {}'.format(image_file))
joints = db_rec['joints_3d']# 人体3d关键点的所有坐标
joints_vis = db_rec['joints_3d_vis']# 人体3d关键点的所有可视坐标
# 获取训练样本转化之后的center以及scale,
c = db_rec['center']
s = db_rec['scale']
# 如果训练样本中没有设置score,则加载该属性,并且设置为1
score = db_rec['score'] if 'score' in