姿态估计0-08:DenseFusion(6D姿态估计)-源码解析(4)-PoseNet网络loss详解(重点篇)

以下链接是个人关于DenseFusion(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:a944284742相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号 海量资源。 \color{blue}{ 海量资源}。 海量资源

姿态估计0-00:DenseFusion(6D姿态估计)-目录-史上最新无死角讲解https://blog.csdn.net/weixin_43013761/article/details/103053585

代码引导

在tools/train.py中,我们可以看到如下代码:

	 points, choose, img, target, model_points, idx = data
     # points:由深度图计算出来的点云,该点云数据以摄像头主轴参考坐标
     # choose:所选择点云的索引,[bs, 1, 500]
     # img:通过box剪切下来的RGB图像
     # target:根据model_points点云信息,以及旋转偏移矩阵转换过的点云信息[bs,500,3]
     # model_points:目标初始帧(模型)对应的点云信息[bs,500,3]
     # idx:目标物体的序列编号

	 pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
	 # 进行预测获得,获得预测的姿态,姿态预测之前的特征向量
	 # pred_r: 预测的旋转参数[bs, 500, 4]
	 # pred_t: 预测的偏移参数[bs, 500, 3]
	 # pred_c: 预测的置信度[bs, 500, 1]
	
	
	 # 对结果进行评估,计算loss
	 loss, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

通过上小结,我们可以知道estimator是对姿态进行估算,其函数返回的emb就是当前需要估算姿态图像(RBG)抽取出来的特征向量。从data中迭代出来的target,已经从数据预处理章节中推导出来,他是由model_points(目标物体第一帧点云)根据标准的参数转换到当前帧的点云数据,可以理解为他就是一个标签ground truth。points表示的是当前帧的点云数据,注意target的参考基准为model_points,points的参考基准是摄像头。refine_start标记已经是否开始了refine网络的训练。好了这样大家就明白网络的输入了,我们进入lib/loss.py来看看其上criterion函数的实现(大致浏览一下,后面本人有代码领读的-麻烦的事情交给我就好,把你的时间用在刀刃上)。

from torch.nn.modules.loss import _Loss
from torch.autograd import Variable
import torch
import time
import numpy as np
import torch.nn as nn
import random
import torch.backends.cudnn as cudnn
from lib.knn.__init__ import KNearestNeighbor


def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
    """
    :param pred_r: 预测的旋转参数[bs, 500, 4],相对于摄像头
    :param pred_t: 预测的偏移参数[bs, 500, 3],相对于摄像头
    :param pred_c: 预测的置信度参数[bs, 500, 1],相对于摄像头
    :param target: 目标姿态,也就是预测图片,通过标准偏移矩阵,结合model_points求得图片对应得点云数据[bs,500,3],这里点云数据,就是学习的目标数据
    :param model_points:目标模型的点云数据-第一帧[bs,500,3]
    :param idx:随机训练的一个索引
    :param points:由深度图计算出来的点云,也就是说该点云数据以摄像头为参考坐标
    :param refine:标记是否已经开始训练refine网络
    :param num_point_mesh:500
    :param sym_list:对称模型的序列号
    """
    print('='*50)
    #print('pred_r.shape: {0}', format(pred_r.shape))
    # print('target.shape: {0}', format(target.shape))
    # print('model_points.shape: {0}', format(model_points.shape))
    # print('points.shape: {0}', format(points.shape))

    knn = KNearestNeighbor(1)
    # [bs, 500, 1]
    bs, num_p, _ = pred_c.size()

    # 把预测的旋转矩阵进行正则化
    pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
    #print('pred_r.shape: {0}', format(pred_r.shape))

    # base[bs,500, 4] -->[500, 3, 3],把预测的旋转参数,转化为旋转矩阵矩阵,
    base = torch.cat(((1.0 - 2.0*(pred_r[:, :, 2]**2 + pred_r[:, :, 3]**2)).view(bs, num_p, 1),\
                      (2.0*pred_r[:, :, 1]*pred_r[:, :, 2] - 2.0*pred_r[:, :, 0]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (2.0*pred_r[:, :, 0]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 1]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (2.0*pred_r[:, :, 1]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 3]*pred_r[:, :, 0]).view(bs, num_p, 1), \
                      (1.0 - 2.0*(pred_r[:, :, 1]**2 + pred_r[:, :, 3]**2)).view(bs, num_p, 1), \
                      (-2.0*pred_r[:, :, 0]*pred_r[:, :, 1] + 2.0*pred_r[:, :, 2]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (-2.0*pred_r[:, :, 0]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 1]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (2.0*pred_r[:, :, 0]*pred_r[:, :, 1] + 2.0*pred_r[:, :, 2]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (1.0 - 2.0*(pred_r[:, :, 1]**2 + pred_r[:, :, 2]**2)).view(bs, num_p, 1)), dim=2).contiguous().view(bs * num_p, 3, 3)
    #print('base.shape: {0}', format(base.shape))

    # 把相对于摄像头的偏移矩阵记录下来
    ori_base = base

    # [3, 3, 500]
    base = base.contiguous().transpose(2, 1).contiguous()

    # 复制num_p=500次,[bs,1,500,3]-->[500,500,3],这里的复制操作,主要是因为每个ground truth(target)点云,
    # 需要与所有的predicted点云做距离差,
    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    #print('model_points.shape: {0}', format(model_points.shape))

    # 复制num_p=500次,[bs,1,500,3]-->[500,500,3],这里的复制操作,主要是因为每个ground truth(target)点云,
    # 需要与所有的predicted点云做距离差,
    target = target.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    #print('target.shape: {0}', format(target.shape))


    # 把初始的目标点云(已经通过标准的pose进行了变换)记录下来
    ori_target = target
    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
    # 把原始预测的偏移矩阵记录下来,这里的t是相对摄像头的
    ori_t = pred_t

    # 当前帧的点云,结合深度图计算而来,也就是说该点云信息是以摄像头为参考目标
    points = points.contiguous().view(bs * num_p, 1, 3)
    pred_c = pred_c.contiguous().view(bs * num_p)

    # 为批量矩阵相乘,model_points与旋转矩阵相乘加上偏移矩阵,得到当前帧对应的点云姿态,该点云的姿态是以model_points为参考的
    #pred[500,500,3]
    pred = torch.add(torch.bmm(model_points, base), points + pred_t)
    #print('pred.shape: {0}', format(pred.shape))

    #print('refine: {0}', format(refine))
    if not refine:
        # 如果是对称的物体
        if idx[0].item() in sym_list:

            # [500,500,3]-->[3,250000]
            target = target[0].transpose(1, 0).contiguous().view(3, -1)
            #print('target.shape: {0}', format(target.shape))

            # [500,500,3]-->[3,250000]
            pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
            #print('pred.shape: {0}', format(pred.shape))

            # [1, 1, 250000],target的每个点云和pred的所有点云进行对比,找到每个target点云与pred的所有点云,距离最近点云的索引(pred)
            inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
            #print('inds.shape: {0}', format(inds.shape))

            # [3, 250000],从target点云中,根据计算出来的min索引,全部挑选出来
            target = torch.index_select(target, 1, inds.view(-1).detach() - 1)
            #print('target.shape: {0}', format(target.shape))

            # [500, 500, 3]
            target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()

            # [500, 500, 3]
            pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()

    # 求得预测点云和目标点云的平均距离(每个点云),按照论文,把置信度和点云距离关联起来
    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)
    loss = torch.mean((dis * pred_c - w * torch.log(pred_c)), dim=0)
    

    # 下面的操作都是为refine模型训练的准备
    pred_c = pred_c.view(bs, num_p)

    # which_max表示的索引下标,即找到置信度最高的哪个下标
    how_max, which_max = torch.max(pred_c, 1)
    print(how_max,which_max)
    #print('which_max.shape: {0}', format(which_max.shape))

    dis = dis.view(bs, num_p)

    # 获得最好的偏移矩阵,这里的t是相对model_points的
    t = ori_t[which_max[0]] + points[which_max[0]]
    points = points.view(1, bs * num_p, 3)

    # 求得500中置信度最高的旋转矩阵,相对于摄像头的
    ori_base = ori_base[which_max[0]].view(1, 3, 3).contiguous()
    ori_t = t.repeat(bs * num_p, 1).contiguous().view(1, bs * num_p, 3)

    # 根据预测最好的旋转矩阵,求得新的当前帧对应的点云,注意这里是一个减号的操作,并且其中的ori_t是相对于摄像头的
    # (但是ori_t和ori_base都是预测出来的,就是返回去肯定存在偏差的)
    new_points = torch.bmm((points - ori_t), ori_base).contiguous()

    new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
    ori_t = t.repeat(num_point_mesh, 1).contiguous().view(1, num_point_mesh, 3)

    # 根据预测最好的旋转矩阵,求得新的当前帧对应的点云,注意这里是一个减号的操作,并且其中的ori_t是相对于摄像头的
    # (但是ori_t和ori_base都是预测出来的,就是返回去肯定存在偏差的,这里的偏差因为new_target是标准的,所以应该少一些)
    new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()

    # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item())
    del knn

    # loss:根据每个点云计算出来的平均loss
    # 对应预测置信度度最高,target与预测点云之间的最小距离
    # new_points:根据最好的旋转矩阵,求得当前帧的点云
    # new_target:就是根据model_points求得的标椎点云
    return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()


class Loss(_Loss):

    def __init__(self, num_points_mesh, sym_list):
        super(Loss, self).__init__(True)
        self.num_pt_mesh = num_points_mesh
        self.sym_list = sym_list

    def forward(self, pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine):

        return loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, self.num_pt_mesh, self.sym_list)


代码领读

1.首先我们我们来说说knn = KNearestNeighbor(1),他作用其实很简单,就是为了加快计算的速度,所以作者使用的是编译成的C++库,看了我论文介绍的朋友,大家应该就知道,在训练对称物体的时候,每个target点云,需要在由估算姿态pose转化而来的pred cloud所有点云中,找到和他距离最近的哪个点云,然后把所有的距离加起来求得平均距离,当做loss。

2.在论文解读中,我多次强调,其是一个稠密(像素/点云级别)的姿态估算,也就是说,他会对每个点云/像素都做一个估算,并且都带有其相应的置信度。所以能够看到这样如下代码:

    # which_max表示的索引下标,即找到置信度最高的哪个下标
    how_max, which_max = torch.max(pred_c, 1)
    
    # 获得最好的偏移矩阵
    t = ori_t[which_max[0]] + points[which_max[0]]
    points = points.view(1, bs * num_p, 3)    

    # 求得500中置信度最高的旋转矩阵
    ori_base = ori_base[which_max[0]].view(1, 3, 3).contiguous()
    ori_t = t.repeat(bs * num_p, 1).contiguous().view(1, bs * num_p, 3)

这样我们就能找到最好的旋转和偏移矩阵了。

3. 重难点 \color{red}{3.重难点} 3.重难点,这里就是思维的转折点了,如果没有认真看我论文翻译的(你在原论文中,是看不到这样详细的讲解的),请通过下面链接翻到后面的部分,读一下:
姿态估计0-03:DenseFusion(6D姿态估计)-白话给你讲论文-翻译无死角(1)
读了之后,大家再来看下面的代码:

    # 根据预测最好的旋转矩阵,求得新的当前帧对应的点云,注意这里是一个减号的操作,并且其中的ori_t是相对于摄像头的
    # (但是ori_t和ori_base都是预测出来的,就是返回去肯定存在偏差的)
    new_points = torch.bmm((points - ori_t), ori_base).contiguous()


    # 根据预测最好的旋转矩阵,求得新的当前帧对应的点云,注意这里是一个减号的操作,并且其中的ori_t是相对于摄像头的
    # (但是ori_t和ori_base都是预测出来的,就是返回去肯定存在偏差的,这里的偏差因为new_target是标准的,所以应该少一些)
    new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()

首先可以看到,他们的变换是一个逆的过程,正的过程是这样的 X 1 = X 2 ⋅ R + t X_{1}=X_{2}·R+t X1=X2R+t,但是上面的公式,很显然是这样的 X 2 = ( X 1 − t ) ⋅ R X_{2}=(X_{1}-t)·R X2=(X1t)R。别给我说不理解啊,担心我打你。
这个逆的操作,分别对points和target操作了,他们偏移的ori_t是相对于摄像头的。也就是都逆转到了摄像头那个空间,很显然,target逆转之后(得到new_target)应该是要比points逆转(得到new_points)的效果是要好的。虽然逆转的参数pose一样,但是new_target本人的点云数据更加准确。现在关键点来了,我们可不可以做一个网络出来,让points偏转回去尽量的接近target偏转回去的空间。当然可以,也就是在偏转的过程中,重新学习一个姿态new_pose,如果两个人都拿初始估算出来的pose,他们怎么偏转都没用啊。所以啊,他们要重新配对,points学习新的姿态new_pose,target还是利用初始估算出来的pose。

大概就是这个意思,
new_target点云说,同样利用估算出来带有差错的pose,我逆操作的效果就是比你好,new_points 你就是垃圾。
new_points 就不服了,同样拿错误的pose,凭什么你做的就要比我好。主要原因在于points本身点云数据就有差错(比如分割-遮挡原因导致),还要加上有差错的pose,就是错上加错了,怎么比得过new_target呢? 最后,他就打算偷偷的去学习,学习到一个新的new_pose,用这个new_pose去弥补和new_target之间的差距 ,也就是把new_target当作目标,尽量的去接近他(超越是不现实的)。原理就是利用new_pose去抵消本身存在的误差。

那么,是怎么去学习这个新new_pose,这里就涉及到PoseRefineNet了,也就是下篇博客要讲解的内容。如果看得明白,麻烦各位朋友给个赞啊。

在这里插入图片描述

  • 27
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

江南才尽,年少无知!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值