目录
论文简介
论文标题:AutoLink: Self-supervised Learning of Human Skeletons and Object Outlines by Linking Keypoints
(AutoLink:基于关键点链接的人体骨架和物体轮廓的自监督学习)
期刊:NeurIPS 2022 (Spotlight)
作者:Xingzhe He Bastian Wandt Helge Rhodin
原文链接:https://arxiv.org/pdf/2205.10636.pdf
论文摘要
关键点检测的应用:关键点等结构化表示广泛应用于姿态转换、条件图像生成、动画和3D重建
需要解决的问题:全监督学习的关键点检测标注费时费力,而且不准确
本文提出:一种自监督方法,该方法学习用由直边连接的二维关键点图从外观中分离出物体的结构
训练数据:大量无标记的数据样本
得到的结果:得到关键点位置,和两两连接它们的边的权值
优势:生成的图形是可解释的(interpretable),例如,当应用于显示人物的图像时,AutoLink可以恢复人体骨架拓扑结构
算法重要内容:i)一种编码器:预测关键点在输入图像中的位置;ii)一个共享图:作为潜在变量,链接每一张图像中的相同的关键点对;iii)一种中间边图(intermediate edge map):以一种软的、可微的方式结合了潜在图的边缘权重和关键点位置;iv)根据随机掩码图像进行图像重建
评价:虽然更简单,但AutoLink在已建立的关键点和姿态估计基准上优于现有的自监督方法,并为在更多样化的数据集上使用结构条件生成模型铺平了道路。
总结一下这篇论文的重点(通俗理解)
论文核心和目标
论文核心要点
- 实现自监督的方法:引入图像重建作为辅助任务(pretext task),设计损失函数为VGG感知损失函数(Perceptual Loss),模型优化目标是让感知损失函数最小化,即让重建后的图像尽可能与输入图像在感知上一致
- 骨架边图(edge map)权值共享
P.S. 辅助任务(pretext task)是自监督学习中常见的部分,作用还是为主要的训练任务服务。本文中图像重建不是该算法的最终目的,只是通过图像重建来计算损失函数,进而实现整体优化目标。
常见的辅助任务有:解决Jigsaw Puzzles拼图问题、恢复图片损失部分、将黑白图片转成彩色、使用输入的一个channel取预测另一个channel、随机打乱视频帧学习对它们排序 等等
核心思想
- 在同类物体的所有不同图片样例中(比如人体姿态,人物身材穿着不同,动作也不同),通过引入一个显式图(explicit graph)来链接所有样例中的相同的关键点对(比如头、脖子、肩膀、手肘、膝盖等,是所有人共有的关键点对),从而利用相同对象共享相同的拓扑结构 P.S.而其他方法都是将对象建模为一系列独立的部分。
- 关键点是每张图片特定的,而边图是共享的
- 通过学习具有一组边权值(edge weights)的共享图(gragh),在同一类的实例之间强制执行相同的拓扑结构
- 通过迫使模型生成结构信息(关键点和骨架边图)来重建原始图像,检测器会收敛以生成具有代表性的图像结构
目标
论文的目标是从大量的无标记数据集中,用自监督的方法训练出有出色结构识别能力的模型,输入图像可以得到一致且可解释的关键点和骨架边图。
算法流程
将128x128分辨率的图像输入到Detector得到关键点坐标张量,计算骨架边图 (edge map) ;之后,将输入图像进行随机掩码80%,与edge map一同输入进Decoder,进行图像重建;最后计算VGG感知损失,进行迭代优化。
使用Adam优化器 (Adam optimizer) ,学习率为,batch size为64,训练20k iterations
注意,上图右边learned graph edge weights的虚线是连到了多张图片的edge map绘制过程中的,这表示learned graph edge weights是由同种类多张不同样例共享的
在图像重建任务中,结构边图(edge map)仅提供几何信息,随机掩码后的图像(masked image)仅提供外观(appearance)信息。
所以说,虽然是利用检测器检测的关键点和边图信息来进行图像重建,但图像重建的目的还是为了让模型更好的训练出关键点和骨架结构识别能力。
为什么需要随机掩码?
我的理解是让随机掩码后的图像(masked image)仅提供外观信息,而不让其反映结构信息。目的是要让编码器独立自主地训练自己的结构识别能力,在图像重建的过程中,让解码器利用edge map来获取物体结构信息而不依赖于输入图像。
若对输入图像不进行随机掩码,直接陪同edge map输入给解码器来重建图像,会使得在训练初期,重建后的图像在结构上就会与输入图像非常相似,使损失值变得很小,引起训练误差很小但模型泛化能力较差(模型对陌生图片的结构识别能力较差),即过拟合现象。
一致性检验
MAFL是celebA数据集中的一部分,附带有人脸五官的5个关键点标注。
MAFL训练集:训练一个从检测到的关键点到MAFL训练集上Ground Truth的线性回归
MAFL测试集:通过眼间距离归一化得到的平均L2误差
用于评价模型训练后的检测关键点的一致性情况
代码研读
坐标网格生成
目标:生成论文中提到的 归一化的像素坐标网格
编码器Encoder
编码器encoder.py文件中,Encoder模型包含了Detector模型对输入图片检测热图和关键点,还有对输入图片随机掩码的操作。
1. Keypoint Detector模型代码:主要是一个ResNet网络,包含ResBlock和TransposedBlock两类,分别写两个Block的类,然后在Detector类里进行顺序堆叠,结尾多加一个卷积层
self.conv = nn.Sequential(
ResBlock(3, 64), # 输出图像大小(64,64,64)
ResBlock(64, 128), # (32,32,128)
ResBlock(128, 256), # (16,16,256)
ResBlock(256, 512), # (8,8,512)
TransposedBlock(512, 256), # (16,16,256)
TransposedBlock(256, 128), # (32,32,128)
nn.Conv2d(128, self.n_parts, kernel_size=3, padding=1), # (32,32,4)
)
2. Detector的前向传播:计算 heatmap 和 keypoint
如下方公式,是一个soft-argmax。假设训练的batch_size = 64,用ResNet拟合H(img)映射,公式里H(p)表示是输入(64,3,128,128)图片得到对应的(64,4,32,32)的热图,其中,H(p)在代码里为prob_map。然后将每张图片每数据都放在第三维(行),得到(64,4,1024,1)格式的热图,之后进行softmax。然后按行相加,得到(64,4,2)大小的关键点张量,存放64张图片4个关键点的xy坐标。
P.S 原文中写的上面这个公式,我觉得跟代码有些出入——代码通过上面提到的自定义ResNet网络,self.conv(img) 得到heatmap,而原文公式是H(p),p是像素坐标张量,不太理解。
应该是,H映射表示输入原始图片,然后得到热图(heatmap)H(img)
def forward(self, input_dict: dict) -> dict:
# F.interpolate(Tensor, ...) 确保图片大小为128x128
# 插值后img的大小仍然为 torch.Size([64, 3, 128, 128])
img = F.interpolate(input_dict['img'], size=(128, 128), mode='bilinear', align_corners=False)
# self.conv(img).shape >> torch.Size([64, 4, 32, 32])
# prob_map.shape >> (64, 4, 32*32=1024, 1)
# reshape是默认将第三第四维的矩阵按行取值组成列向量
prob_map = self.conv(img).reshape(img.shape[0], self.n_parts, -1, 1)
prob_map = F.softmax(prob_map, dim=2)
# 张量相乘后 (1,1,1024,2) * (64,4,1024,1) = (64,4,1024,2) 相乘时会广播到缺少的维度,然后逐元素相乘
# 将self.coord 视为可学习的权重 prob_map视为列向量(以第三四维来看)
keypoints = self.coord * prob_map
# torch.sum 若不指明维度,则将所有元素求和为一个数
# 四维张量可以想像为64个(4x1024x2)的正方体 将每个正方体按行相加 行维度消失 64个(4x2)的矩阵 重新组成为 (64,4,2)张量
keypoints = keypoints.sum(dim=2)
# (64,4,32,32)
prob_map = prob_map.reshape(keypoints.shape[0], self.n_parts, self.output_size, self.output_size)
return {'keypoints': keypoints, 'prob_map': prob_map}
3. Encoder的前向传播:包含了Detector的前向传播操作,并且还会做输入图像的随机掩码
4. 随机掩码:
对输入图像进行随机掩码,将128x128的图像分为16个方格,随机遮蔽80%的方格
已设定batch_size = 64,self.missing = 0.8 ,则
- torch.zeros().uniform_() 生成大小为(64,1,16,16)的、元素随机取自[0,1]均匀分布的张量torch.zeros().uniform_() > self.missing 表示张量中元素大于0.8则取True,否则取False
- 然后,F.interpolate(damage_mask.to(input_dict['img']), size= 128 , mode = 'nearest'),采用nearest算法做上采样,从16x16上采样为128x128图像大小。damage_mask.to(input_dict['img']) 会将张量中元素从布尔值转为浮点型的0和1,同时device也会转为gpu。每一个大格中8x8个像素点都为同一个值(0或1)
- 最后,将图像张量 按相应元素 乘以damage_mask,就会有80%大格,大格子内的所有像素要取0,其余的保留原值
当然这一操作是在同一batch内,分图片分通道进行的,两个张量相乘 (64,3,128,128) * (64,1,128,128) ,damage_mask第二维会广播到3,即输入图片三个颜色通道都会进行随机掩码
'''例:
a = torch.zeros(1, 1, 2, 3)
b = a.uniform_() > 0.8
print(a, "\n", b)
tensor([[[[0.8816, 0.8507, 0.3285],
[0.3228, 0.5518, 0.9808]]]])
tensor([[[[ True, True, False],
[False, False, True]]]])'''
# damage_mask >> (64,1,16,16) .uniform_() 从(0,1)均匀分布里对tensor所有元素随机重新取值
damage_mask = torch.zeros(input_dict['img'].shape[0], 1, self.block, self.block, device=input_dict['img'].device).uniform_() > self.missing
# damage_mask.to(input_dict['img'])是在将damage_mask的device和dtype转成与其相同
# 将128x128分为16x16的格子 每个格子包含的8x8个像素点取同一个值
damage_mask = F.interpolate(damage_mask.to(input_dict['img']), size=input_dict['img'].shape[-1], mode='nearest')
# 元素对应相乘
mask_batch['damaged_img'] = input_dict['img'] * damage_mask
解码器Decoder
解码器decoder.py文件主要包含 计算最终骨架边图、图像重建等操作
1. 绘制可微的骨架边图 (edge map) Sij
编码器得到所有关键点坐标后,假设取两个关键点和。在两个关键点连接处Sij为1,而随像素点到edge直线的距离指数性地减少。形式上,骨架边图Sij是沿直线的高斯扩展。
其中,是为骨架边 (edge) 的厚度,为像素张量到连接和的骨架边的L2距离,t是与在edge上投影点的归一化距离
以下代码的draw_lines函数,输入paired_joints张量(对每张图片,存放6组,每组两对关键点的坐标 ),得到Sij,注意绘制的Sij也就是分散的edge map是每张16x16大小的
def draw_lines(paired_joints: torch.Tensor, heatmap_size: int=16, thick: Union[float, torch.Tensor]=1e-2) -> torch.Tensor:
"""
Draw lines on a grid.
:param paired_joints: (batch_size, n_points, 2, 2)
:return: (batch_size, n_points, grid_size, grid_size)
dist[i,j] = ||x[b,i,:]-y[b,j,:]||^2
"""
# bs为batch_size ,n_points为关键点数量 bs = 64 n_points = 4
bs, n_points, _, _ = paired_joints.shape # .shape得到张量的形状,返回元组形式(bs,n_points,2,2)
start = paired_joints[:, :, 0, :] # (batch_size, n_points, 2)
end = paired_joints[:, :, 1, :] # (batch_size, n_points, 2)
paired_diff = end - start # (batch_size, n_points, 2)
# (1,1,16*16,2)
grid = gen_grid2d(heatmap_size).to(paired_joints.device).reshape(1, 1, -1, 2)
# .unsqueeze(-2) 增维 表示在倒数第2维插入一个维度(即,在第三维插入一个维度)
# (1,1,16*16,2) - (64,4,?,2) = (64,4,1024,2)
diff_to_start = grid - start.unsqueeze(-2) # (batch_size, n_points, heatmap_size**2, 2)
# (batch_size, n_points, heatmap_size**2)
# @为矩阵乘法
t = (diff_to_start @ paired_diff.unsqueeze(-1)).squeeze(-1) / (1e-8+paired_diff.square().sum(dim=-1, keepdim=True))
diff_to_end = grid - end.unsqueeze(-2) # (batch_size, n_points, heatmap_size**2, 2)
# .sum() 对于二维数组,dim=0,就是将所有行合并成一行;dim=1,就是将所有列合并成一列,
# 对于三维张量,dim=0对应第三维,dim=1是行, dim=2是列 dim=-1表示倒数第一维即dim=2
before_start = (t <= 0).float() * diff_to_start.square().sum(dim=-1) # .square() 张量各元素取平方
after_end = (t >= 1).float() * diff_to_end.square().sum(dim=-1)
between_start_end = (0 < t).float() * (t < 1).float() * (grid - (start.unsqueeze(-2) + t.unsqueeze(-1) * paired_diff.unsqueeze(-2))).square().sum(dim=-1)
squared_dist = (before_start + after_end + between_start_end).reshape(bs, n_points, heatmap_size, heatmap_size)
heatmaps = torch.exp(- squared_dist / thick) #负号?
return heatmaps
对于求解dij的原理,我自己的理解如下:
为每一条边分配权重,这个权重是在训练中可学习的,并且是同一数据集内所有不同实例对应的edge所共享的
最后,取多张heatmap中每一个像素点的最大值,获得最终的热图
在Decoder模型代码中, skeleton_idx为骨骼边的索引,skeleton_scalar为(4x4)的边共享权重矩阵
# torch.triu_indices返回上三角矩阵元素的索引,第一行为索引的行坐标,第二行为索引的列坐标
# self.skeleton_idx为 (2,6)大小的张量
self.skeleton_idx = torch.triu_indices(self.n_parts, self.n_parts, offset=1)
self.n_skeleton = len(self.skeleton_idx[0]) # 6
self.alpha = nn.Parameter(torch.tensor(1.0), requires_grad=True)
# 生成4x4张量 元素满足随机标准正态分布N(0,1)
skeleton_scalar = (torch.randn(self.n_parts, self.n_parts) / 10 - 4) / self.sklr
self.skeleton_scalar = nn.Parameter(skeleton_scalar, requires_grad=True)
代码中rasterize函数就是通过输入keypoints得到最终的骨架边图(在代码里是skeleton_heatmap),torch.triu保留skeleton_scalar上三角元素(不含主对角线元素),其余取0
def rasterize(self, keypoints: torch.Tensor, output_size: int=128) -> torch.Tensor:
"""
Generate edge heatmap from keypoints, where edges are weighted by the learned scalars.
:param keypoints: (batch_size, n_points, 2)
:return: (batch_size, 1, heatmap_size, heatmap_size)
"""
# paired_joints大小为(64,6,2,2)
paired_joints = torch.stack([keypoints[:, self.skeleton_idx[0], :2], keypoints[:, self.skeleton_idx[1], :2]], dim=2)
# 4x4
skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
# diagonal=1 不包含主对角线上的元素
skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
# skeleton_scalar[self.skeleton_idx[0], self.skeleton_idx[1]] 将上三角元素依次转为行向量
# (1,6,1,1) 分别对应6个边连接组合
skeleton_scalar = skeleton_scalar[self.skeleton_idx[0], self.skeleton_idx[1]].reshape(1, self.n_skeleton, 1, 1)
skeleton_heatmap_sep = draw_lines(paired_joints, heatmap_size=output_size, thick=self.thick)
# skeleton_scalar为权重 每个像素点的edge map值乘以权重矩阵的对应值
# (64,6,128,128) = (64,6,128,128) * (1,6,1,1)
skeleton_heatmap_sep = skeleton_heatmap_sep * skeleton_scalar.reshape(1, self.n_skeleton, 1, 1)
# torch.max()[0] ,只返回最大值的每个数 torch.max()[1]为最大值索引
# (64,1,128,128)
skeleton_heatmap = skeleton_heatmap_sep.max(dim=1, keepdim=True)[0]
return skeleton_heatmap
其中,paired_joints是4个关键点相互连接的6组配对形式对应的关键点坐标,6组配对形式如前面提到的4x4上三角矩阵上三角元素(不含主对角线元素)的索引
torch.stack操作与torch.cat不同点在于,前者合并的维度是新增的维度 ,比如两个2维张量stack会变成一个3维张量。下面是torch.cat()和torch.stack()的例子:
paired_joints的获取过程如下,可见,keypoints变换操作后的两组 (64,6,2) 的三维张量,经过torch.stack后得到 (64,6,2,2) 大小的四维张量
得到的paired_joints张量每个小2x2方块表示两个相互连接的关键点各自的坐标
对于toch.max(dim = 1,keepdim = True)[0]函数,作用如下:
2. 图像重建
从下方代码中Decoder模型的前向传播可看出,先通过写好的rasterize()函数,通过关键点绘制edge map(代码里是skeleton_heatmap)。然后根据下方公式,将随机掩码图像和edge map按比例连接,输入到UNet进行图像重建。注意,这里的连接是利用torch.cat将两个大小分别为(64,3,128,128)和(64,1,128,128)的张量按第二维连接成(64,4,128,128)
Decoder模型中的网络定义为如下:
self.down0 = nn.Sequential(
nn.Conv2d(3 + 1, 64, kernel_size=(3, 3), padding=1),
nn.LeakyReLU(0.2, True),)
self.down1 = DownBlock(64, 128) # 64
self.down2 = DownBlock(128, 256) # 32
self.down3 = DownBlock(256, 512) # 16
self.down4 = DownBlock(512, 512) # 8
self.up1 = UpBlock(512, 512) # 16
self.up2 = UpBlock(512 + 512, 256) # 32
self.up3 = UpBlock(256 + 256, 128) # 64
self.up4 = UpBlock(128 + 128, 64) # 64
self.conv = nn.Conv2d(64+64, 3, kernel_size=(3, 3), padding=1)
Decoder模型的前向传播如下,可见模型得到的输出为总体的edge map和重建后图片
def forward(self, input_dict: dict) -> dict:
# (64,1,16,16)
skeleton_heatmap = self.rasterize(input_dict['keypoints'])
# torch.cat([(64,3,128,128),(64,1,128,128)])
# x为(64,4,128,128)
x = torch.cat([input_dict['damaged_img'] * self.alpha, skeleton_heatmap], dim=1) #按第二维(深度)拼接
down_128 = self.down0(x)
down_64 = self.down1(down_128)
down_32 = self.down2(down_64)
down_16 = self.down3(down_32)
down_8 = self.down4(down_16)
up_8 = down_8
up_16 = torch.cat([self.up1(up_8), down_16], dim=1)
up_32 = torch.cat([self.up2(up_16), down_32], dim=1)
up_64 = torch.cat([self.up3(up_32), down_64], dim=1)
up_128 = torch.cat([self.up4(up_64), down_128], dim=1)
img = self.conv(up_128) #I' 重建后图片
input_dict['heatmap'] = skeleton_heatmap
input_dict['img'] = img
return input_dict
损失函数
感知损失函数(VGG Perceptual Loss)可以衡量结构的相似性。公式如下,为一个batch中的图片数,为特征提取器
以下为特征提取器的网络,在VGG Perceptual Loss模型的前向传播里,对原始输入图片和重建图片做公式里的操作,即可得到损失值
blocks = [torchvision.models.vgg16(weights='DEFAULT').features[:4].eval(),
torchvision.models.vgg16(weights='DEFAULT').features[4:9].eval(),
torchvision.models.vgg16(weights='DEFAULT').features[9:16].eval(),
torchvision.models.vgg16(weights='DEFAULT').features[16:23].eval()]