《SRN: Stacked Regression Network for Real-time 3D Hand Pose Estimation》略读与实践
这篇与《AWR: Adaptive Weighting Regression for 3D Hand Pose Estimation》相似,本质都还是3Dheatmap用来做深度图手部关键点坐标估计,最大不同在于这篇论文是迭代stage的方法不断finetune回归关键点坐标的。
作者认为去掉解码网络结构,模型会更小更快,效果不一定会变差,因为之前典型的具有stage的模型结构特征提取部分都是hourglass结构的,涉及编码和解码两个过程。同时作者还认为dense pixel-wise(heatmap)方案的解码部分并不是很有效,太关注局部特征,如果标签坐标附近有深度值缺失,效果更不可靠。因此需要一种可微分的参数可重置或复用的模块来直接提取空间特征。
官方开源代码:https://github.com/RenFeiTemp/SRN
废话不多说,上整体结构示意图:
上图清晰可以看出为了干掉解码结构,每个stage都复用最开始的特征提取模块和深度图数据,结构的核心和难点就是理解上图中的橘黄色模块"Regression module"。
作者也是使用了3D offset vector heatmap和概率heatmap,采用smooth L1 loss。
下面看下官方模型测试效果(妖哥亲测):
上面四个结果均为最后一个stage的结果,原文中代码有三个stage,最后一个是效果最好的,看上去感觉还可以。
废话不多说,show your code!(核心代码)
def joint2offset(self,joint,img,feature_size=32): #如何将关节点坐标转为3Dheatmap
device = joint.device
batch_size,_,img_height,img_width = img.size() # B X 128 X 128
#print(' batch_size,_,img_height,img_width ', batch_size, ' ', img_height,' ', img_width)
img = F.interpolate(img, size=[feature_size, feature_size]) # 32 X 32
#print('img ', img.size())
_,joint_num,_ = joint.view(batch_size,-1,3).size() # joint shape 1 X 21 X 3
joint_feature = joint.view(joint.size(0),-1,1,1).repeat(1, 1, feature_size, feature_size) # 63 X 32 X 32
#print('joint_feature ', joint_feature.size(), ' ',joint_feature)
mesh_x = 2.0 * torch.arange(feature_size).unsqueeze(1).expand(feature_size, feature_size).float() / (feature_size -