【DenseFusion代码详解】主干网络loss计算

DenseFusion系列代码全讲解目录:【DenseFusion系列目录】代码全讲解+可视化+计算评估指标_Panpanpan!的博客-CSDN博客

这些内容均为个人学习记录,欢迎大家提出错误一起讨论一起学习!


该部分是对主干网络部分的loss进行计算。代码位置在lib/loss.py

什么时候进行loss的计算呢?

train/test模式下,对数据进行预处理之后,随机选取了num_points个点云(这里LineMOD数据集为500个,YCB数据集为1000个,文中以LineMOD数据集为主),获取了点云、image crop、choose等等,然后训练网络为每个像素回归旋转、平移、置信度,这个时候,就需要对网络输出的姿态计算loss。

train.py中对其的使用过程为:

from lib.loss import Loss #第26行-首先import
criterion = Loss(num_points_mesh, sym_list) #第108行-初始化
loss, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start) #第140行-进行forward过程求解loss

首先来看Loss类的定义:

from torch.nn.modules.loss import _Loss
class Loss(_Loss):

    def __init__(self, num_points_mesh, sym_list):  #初始化
        super(Loss, self).__init__(True)
        self.num_pt_mesh = num_points_mesh  #点数
        self.sym_list = sym_list #对称物体编号

    def forward(self, pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine):

        return loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, self.num_pt_mesh, self.sym_list) #计算loss

Loss类继承了torch.nn.modules.loss类,用于定义自己的损失函数。首先进行初始化,参数有mesh点数以及对称物体编号,在上述train.py的 criterion = Loss(num_points_mesh, sym_list) 中实现初始化,定义criterion。然后在训练过程中,使用criterion()的时候调用forward函数,括号中的参数信息为:

  • pred_r:预测的旋转R,每个像素都有一个,大小为torch.Size([1, 500, 4]),4是四元数表示
  • pred_t:预测的平移t,torch.Size([1, 500, 3])
  • pred_c:预测的置信度c,torch.Size([1, 500, 1])
  • target:目标点云,torch.Size([1, 500, 3]) 
  • model_points:模型第一帧的点云,torch.Size([1, 500, 3])
  • idx:类别编号,torch.Size([1, 1])
  • points:筛选的500个点云,torch.Size([1, 500, 3])
  • w:平衡超参数,默认值0.015,正则化,用于平衡置信度
  • refine:是否已经开始refine过程,True/False

forward函数中调用的计算loss的函数loss_calculation,下面一行一行分析。

首先,传入的参数除了上述之外加了初始化的两个参数:

def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
    knn = KNearestNeighbor(1)  #KNN算法
    bs, num_p, _ = pred_c.size() 
    pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))

第一行定义了KNN算法,为了处理对称物体,后续再详细介绍。然后获取pred_c的大小,bs为1,num_p为500,然后对pred_r进行标准化,torch.norm(pred_r, dim=2) 用来对pred_r在最后一维上求L2范数。接着:

    base = torch.cat(((1.0 - 2.0*(pred_r[:, :, 2]**2 + pred_r[:, :, 3]**2)).view(bs, num_p, 1),\
                      (2.0*pred_r[:, :, 1]*pred_r[:, :, 2] - 2.0*pred_r[:, :, 0]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (2.0*pred_r[:, :, 0]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 1]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (2.0*pred_r[:, :, 1]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 3]*pred_r[:, :, 0]).view(bs, num_p, 1), \
                      (1.0 - 2.0*(pred_r[:, :, 1]**2 + pred_r[:, :, 3]**2)).view(bs, num_p, 1), \
                      (-2.0*pred_r[:, :, 0]*pred_r[:, :, 1] + 2.0*pred_r[:, :, 2]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (-2.0*pred_r[:, :, 0]*pred_r[:, :, 2] + 2.0*pred_r[:, :, 1]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (2.0*pred_r[:, :, 0]*pred_r[:, :, 1] + 2.0*pred_r[:, :, 2]*pred_r[:, :, 3]).view(bs, num_p, 1), \
                      (1.0 - 2.0*(pred_r[:, :, 1]**2 + pred_r[:, :, 2]**2)).view(bs, num_p, 1)), dim=2).contiguous().view(bs * num_p, 3, 3)

这一大段就是求旋转矩阵R,DenseFusion中使用的是四元数(常用的四元数、欧拉角等)来表示旋转矩阵,网络回归出的是4个数值,现在要把它们转换成原始的9个数值,公式如下:

上述求base的过程就是该公式的实现,base的大小为torch.Size([500, 3, 3]) 。

    ori_base = base #旋转
    base = base.contiguous().transpose(2, 1).contiguous() #转置/逆
    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3) #把model_points复制500份
    target = target.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3) 
    ori_target = target
    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
    ori_t = pred_t
    points = points.contiguous().view(bs * num_p, 1, 3)
    pred_c = pred_c.contiguous().view(bs * num_p)

    pred = torch.add(torch.bmm(model_points, base), points + pred_t)

这里把model_points和target都复制了500份,为了对每个像素的预测结果都进行计算,大小从torch.Size([1, 500, 3])变换到torch.Size([500, 500, 3])。这些操作都是重组变量的结构。最后,计算预测的模型pred,这里重点理解的是最后一行代码:pred = torch.add(torch.bmm(model_points, base), points + pred_t) 将model_points与base相乘,然后加上偏移矩阵,加上points,得到预测的模型点。这里我想了很久,为什么用model_points为基准?为什么还要加上points?下面说说我个人的看法:

首先,model_points是第一帧相机坐标系(xc0,yc0,zc0)下的点云,target也是由model_points经过标准R|t转换而来的另一视角i相机坐标系(xci,yci,zci)下的点,也就是说标准R|t是(xc0,yc0,zc0)和(xci,yci,zci)之间的转换,那么所求的base,也就是网络回归出的旋转,也是这两个坐标系之间的转换,而网络回归出的平移,不是绝对的平移,而是与points之间的差值,加上points之后,才是绝对的平移,而这个绝对的平移也是上述两个相机坐标系的转换。

另外一个理解是,可以把model_points的相机坐标系视为与世界坐标系对齐,就直接把model_points作为世界坐标系下的点,后面的每一帧都视为不同的相机坐标系,姿态就是世界坐标系与不同的相机坐标系之间的转换,这样理解可能更方便,虽然我并没有找到关于世界坐标系的描述。(个人理解,如果有不对欢迎指正)

    if not refine:
        if idx[0].item() in sym_list:
            target = target[0].transpose(1, 0).contiguous().view(3, -1)
            pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
            inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
            target = torch.index_select(target, 1, inds.view(-1).detach() - 1)
            target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()
            pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()

如果还没有开始refine,且物体是对称物体,那么就计算ADD-S:

 计算真实模型上的点和预测模型上的所有点之间距离的最小值,这一步是由knn实现的,返回inds为最小距离点的索引,然后选取对应索引的点重新构成target。

    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)
    loss = torch.mean((dis * pred_c - w * torch.log(pred_c)), dim=0)

 如果不是对称物体则计算ADD:

计算真实模型上的点和预测模型上的对应点之间的距离,求和再取平均,总的loss就是每个像素的loss与其置信度相乘,然后加了一个平衡参数w:

为什么要加入后面那一项?如果只看前面那一项,那么默认置信度越大的像素所提供的loss就越大,但这不是我们想要的结果,置信度越大,应该说明这个像素更加重要,预测的结果更加准确,而不是损失越大,因此加入了-log(ci),取负对数之后,置信度越大,这一项就越小,因此总loss就会减少,w就是为了平衡这两项之间的关系。

    pred_c = pred_c.view(bs, num_p)
    how_max, which_max = torch.max(pred_c, 1)
    dis = dis.view(bs, num_p)

把pred_c置信度和dis都转换成torch .Size([1, 500]),how_max,which_max为置信度最大的值和索引。

    t = ori_t[which_max[0]] + points[which_max[0]]
    points = points.view(1, bs * num_p, 3)

 这里找到置信度最大的像素预测的平移ori_t和该像素对应的点云points,相加得到绝对的平移t(正如上述所说的一样)。

    ori_base = ori_base[which_max[0]].view(1, 3, 3).contiguous()
    ori_t = t.repeat(bs * num_p, 1).contiguous().view(1, bs * num_p, 3)
    new_points = torch.bmm((points - ori_t), ori_base).contiguous()

ori_base就是最大置信度像素预测的旋转矩阵,ori_t是将绝对平移t复制了500个便于之后的计算。对于最后一行的理解,需要看看公式,新的点P2由P1转换而来: P_{2} = P_{1}\cdot R + t,那么P_{1}=(P_{2}-t)\cdot R^{-1},正交矩阵的逆等于它的转置,而ori_base保存的是没有转置之前的base,那么这里就是对points进行逆转操作,这个new_points是为了作为后续refine过程的输入。

    new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
    ori_t = t.repeat(num_point_mesh, 1).contiguous().view(1, num_point_mesh, 3)
    new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()

    del knn
    return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()

 一样地,用target逆转得到new_target,最后返回loss,最大置信度像素的距离值,new_points和new_target。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Panpanpan!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值