计算两个R/T旋转平移矩阵的欧拉角和平移向量的误差

import torch
import numpy as np
from common.math.so3 import dcm2euler
from lib.benchmark_utils import to_array
from common.math_torch import se3
def compute_rot_trans_error1(pred_transforms, gt_transforms):
"""
Euler angles, Individual translation errors (Deep Closest Point convention)
根据给定的旋转平移矩阵计算欧拉角形式的旋转误差,以及计算平移向量的误差
@param pred_transforms: 预测的R/T矩阵
@param gt_transforms: 标签R/T矩阵
@return: r_mse:欧拉角二阶均方误差, r_mae: 欧拉角一阶平均绝对值误差,
to_array(t_mse):平移向量二阶均方误差, to_array(t_mae):平移向量一阶平均绝对值误差
"""
r_gt_euler_deg = dcm2euler(gt_transforms[:, :3, :3].numpy(), seq='xyz')
r_pred_euler_deg = dcm2euler(pred_transforms[:, :3, :3].numpy(), seq='xyz')
t_gt = gt_transforms[:, :3, 3]
t_pred = pred_transforms[:, :3, 3]
r_mse = np.mean((r_gt_euler_deg - r_pred_euler_deg) ** 2, axis=1)
r_mae = np.mean(np.abs(r_gt_euler_deg - r_pred_euler_deg), axis=1)
t_mse = torch.mean((t_gt - t_pred) ** 2, dim=1)
t_mae = torch.mean(torch.abs(t_gt - t_pred), dim=1)
return r_mse, r_mae, to_array(t_mse), to_array(t_mae)

def compute_rot_trans_error2(pred_transforms, gt_transforms):
"""
计算旋转误差和平移误差,其中旋转误差以旋转向量旋转角(角度,非弧度)形式返回,平移向量误差以二范数标量形式返回
# Rotation, translation errors (isotropic, i.e. doesn't depend on error
direction, which is more representative of the actual error)
@param pred_transforms: 预测的R/T矩阵
@param gt_transforms: 标签R/T矩阵
@return: residual_rotdeg:旋转向量形式下旋转角误差, residual_transmag:平移向量误差
"""
concatenated = se3.concatenate(se3.inverse(gt_transforms), pred_transforms)
rot_trace = concatenated[:, 0, 0] + concatenated[:, 1, 1] + concatenated[:, 2, 2]
residual_rotdeg = torch.acos(torch.clamp(0.5 * (rot_trace - 1), min=-1.0, max=1.0)) * 180.0 / np.pi
residual_transmag = concatenated[:, :, 3].norm(dim=-1)
return to_array(residual_rotdeg), to_array(residual_transmag)



if __name__ == '__main__':
pre_transforms = torch.randn((1,3,3))
pre_trans = torch.randn(1,3,1)
pre_transforms = torch.cat([pre_transforms, pre_trans], dim=-1)
gt_transforms = torch.eye(3).unsqueeze(0)
gt_transforms = torch.cat([gt_transforms, torch.randn(1,3,1)], dim=-1)
print(pre_transforms)
print(gt_transforms)
r_mse, r_mae, t_mse, t_mae = compute_rot_trans_error1(pre_transforms, pre_transforms)
residual_rotdeg, residual_transmag = compute_rot_trans_error2(pre_transforms, pre_transforms)
print(r_mse, r_mae, t_mse, t_mae)
print(residual_rotdeg, residual_transmag)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值