SPIN流程

# No reduction because confidence weighting needs to be applied
self.criterion_keypoints = nn.MSELoss(reduction='none').to(self.device)

keypoint_loss

def keypoint_loss(self, pred_keypoints_2d, gt_keypoints_2d, openpose_weight, gt_weight):
""" Compute 2D reprojection loss on the keypoints.
The loss is weighted by the confidence.
The available keypoints are different for each dataset.
"""
	conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
    conf[:, :25] *= openpose_weight
    conf[:, 25:] *= gt_weight
    loss = (conf * self.criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean()
    return loss

keypoint_3d_loss

def keypoint_3d_loss(self, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d):
"""Compute 3D keypoint loss for the examples that 3D keypoint annotations are available.
The loss is weighted by the confidence.
"""
pred_keypoints_3d = pred_keypoints_3d[:, 25:, :]
conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone()
gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1]
conf = conf[has_pose_3d == 1]
pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1]
gt_pelvis = (gt_keypoints_3d[:, 2,:] + gt_keypoints_3d[:, 3,:]) / 2
gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :]
pred_pelvis = (pred_keypoints_3d[:, 2,:] + pred_keypoints_3d[:, 3,:]) / 2
pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :]
return (conf * self.criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean()

shape loss

# Per-vertex loss on the shape
self.criterion_shape = nn.L1Loss().to(self.device)
def shape_loss(self, pred_vertices, gt_vertices, has_smpl):
"""Compute per-vertex loss on the shape for the examples that SMPL annotations are available."""
	return self.criterion_shape(pred_vertices_with_shape, gt_vertices_with_shape)

smpl_losses

# Loss for SMPL parameter regression
self.criterion_regr = nn.MSELoss().to(self.device)

def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl):
	pred_rotmat_valid = pred_rotmat[has_smpl == 1]
    gt_rotmat_valid = batch_rodrigues(gt_pose.view(-1,3)).view(-1, 24, 3, 3)[has_smpl == 1]
    pred_betas_valid = pred_betas[has_smpl == 1]
    gt_betas_valid = gt_betas[has_smpl == 1]
    loss_regr_pose = self.criterion_regr(pred_rotmat_valid, gt_rotmat_valid)
    loss_regr_betas = self.criterion_regr(pred_betas_valid, gt_betas_valid)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值