Real-time Deep Registration With Geodesic Loss 中 Loss 的 pytorch 实现

Real-time Deep Registration With Geodesic Loss 中 Loss 的 pytorch 实现

1. 数学原理

罗德里格旋转公式 (Rodrigues‘ rotation formula), 一个向量绕旋转轴旋转给定角度 θ 以后得到的新向量。

简而言之就是,旋转角度与四元数之间的转换

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2. pytorch实现

Paper: Real-time Deep Pose Estimation with Geodesic Loss for Image-to-Template Rigid Registration

该论文采用网络回归刚性配准的参数,主要借鉴其loss函数。由于自己写的pytorch代码,因此改为直接可用的Loss函数。

在流形中,由上图可以直观的看出测地距离与欧式距离的区别,该论文就摒弃了大多论文中的 MSE误差 ,采用角度差异的误差来进行优化。

原论文的建模过程中有两个简化,核心的式子为
R = I 3 + s i n ( θ ) [ v ] × + ( 1 − c o s θ ) [ v ] × 2 R = I_3 + sin(\theta)[v]_× + (1-cos\theta)[v]_×^2 R=I3+sin(θ)[v]×+(1cosθ)[v]×2 t r ( R ) = 1 + 2 c o s ( θ ) tr(R) = 1 + 2cos(\theta) tr(R)=1+2cos(θ) L o s s G e o d e s i c = d ( R s , R G T ) = c o s − 1 [ t r ( R s T R G T ) − 1 2 ] Loss_{Geodesic} = d(R_s ,R_{GT}) = cos^{-1}[\frac{tr(R_s^TR_{GT}) - 1}{2}] LossGeodesic=d(Rs,RGT)=cos1[2tr(RsTRGT)1]

最终求得的是变换的角度

Code:

class GeodesicLoss(nn.Module):
    def __init__(self):
        super(GeodesicLoss, self).__init__()

    def my_R(self, x):
        R1 = torch.eye(3) + torch.sin(
            x[2]) * x[0] + (1.0 - torch.cos(x[2])) * (x[0] @ x[0])
        R2 = torch.eye(3) + torch.sin(
            x[3]) * x[1] + (1.0 - torch.cos(x[3])) * (x[1] @ x[1])

        return R1.transpose(0, 1) @ R2

    def get_theta(self, x):

        clamp_res = torch.clamp(0.5 * (x[0, 0] + x[1, 1] + x[2, 2] - 1.0),
                                -1.0 + 1e-7, 1.0 - 1e-7)
        acos_res = torch.acos(clamp_res)
        abs_res = torch.abs(acos_res)

        return abs_res

    def forward(self, y_true, y_pred):
        # skew_true: (3, 3, 3)
        # skew_pred: (3, 3, 3)
        # angle_true: (3,)    
        # angle_pred: (3,)
        # R shape: (3, 3, 3)
        angle_true = torch.sqrt(torch.sum(torch.pow(y_true, 2), axis=1))
        angle_pred = torch.sqrt(torch.sum(torch.pow(y_pred, 2), axis=1))

        # compute axes
        axis_true = F.normalize(y_true, p=2, dim=-1).view(3, 3)
        axis_pred = F.normalize(y_pred, p=2, dim=-1).view(3, 3)

        proj = torch.FloatTensor([[0, 0, 0, 0, 0, -1, 0, 1, 0],
                                  [0, 0, 1, 0, 0, 0, -1, 0, 0],
                                  [0, -1, 0, 1, 0, 0, 0, 0, 0]])

        skew_true = (axis_true @ proj).view(3, 3, 3)
        skew_pred = (axis_pred @ proj).view(3, 3, 3)

        r1 = self.my_R((skew_true[0, ...], skew_pred[0, ...], angle_true[0], angle_pred[0]))
        r2 = self.my_R((skew_true[1, ...], skew_pred[1, ...], angle_true[1], angle_pred[1]))
        r3 = self.my_R((skew_true[2, ...], skew_pred[2, ...], angle_true[2], angle_pred[2]))
        R = torch.stack([r1, r2, r3], dim=0)

        theta1 = self.get_theta(R[0, ...])
        theta2 = self.get_theta(R[1, ...])
        theta3 = self.get_theta(R[2, ...])
        theta = torch.stack([theta1, theta2, theta3], dim=0)
        return torch.mean(theta)
2. tf.map_fn()

顺便记录一下 tf.map_fn() 函数,通过官方文档以及查资料明白其就是一个遍历迭代最后stack的过程,但是复现过程中总是遇到了问题,这时候注意把源码中的元组的每个元素都看作为迭代器,分开计算结果就正确了。

R = tf.map_fn(my_R, (skew_true, skew_pred, angle_true, angle_pred), dtype=tf.float32)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值