旋转目标框angle损失计算:

该文章介绍了基于高斯Wasserstein距离(GWD)的旋转目标检测算法,通过计算两个高斯分布的Wasserstein距离来精确匹配旋转矩形框。同时,文章中还包含了将xywhr坐标转换为xy和方差的函数以及KLD损失函数的实现,用于评估预测和目标分布之间的差异。
摘要由CSDN通过智能技术生成

1.旋转矩形框表示为高斯分布:

GWD:基于高斯Wasserstein距离的旋转目标检测 | ICML 2021 - 知乎 (zhihu.com)

2.两个高斯分布Wasserstein distance距离:

(30条消息) Wasserstein距离_RayRings的博客-CSDN博客

3.

def xy_wh_r_2_xy_sigma(xywhr):
    #print(xywhr.type())
    _shape = xywhr.shape
    #print(_shape)
    assert _shape[-1] == 5
    xy = xywhr[..., :2]
    wh = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
    r = xywhr[..., 4]
    cos_r = torch.cos(r)
    sin_r = torch.sin(r)
    R = torch.stack((cos_r, -sin_r, sin_r, cos_r), dim=-1).reshape(-1, 2, 2)
    S = 0.5 * torch.diag_embed(wh)## 由 diag 变为三维 3*3

    sigma = R.bmm(S.square()).bmm(R.permute(0, 2, 1)).reshape(
        _shape[:-1] + (2, 2))#torch.suqare()返回带有输入元素平方的新张量。 矩阵的批量相乘,支持TensorFloat32数据的操作。#要求:input 和 mat2 必须是 3-D 张量,每个张量都包含相同数量的矩阵。如果input的维度是( b × n × m ) (b\times n\times m)(b×n×m),mat2维度是( b × m × p ) (b\times m \times p)(b×m×p),那么返回的结果out就是:( b × n × p ) (b\times n \times p)(b×n×p),那么有:

    return xy, sigma
def kld_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, sqrt=True):
    # todo
    xy_p, Sigma_p = pred
    xy_t, Sigma_t = target
    

    _shape = xy_p.shape

    xy_p = xy_p.reshape(-1, 2)
    xy_t = xy_t.reshape(-1, 2)
    Sigma_p = Sigma_p.reshape(-1, 2, 2)
    Sigma_t = Sigma_t.reshape(-1, 2, 2)

    Sigma_p_inv = torch.stack((Sigma_p[..., 1, 1], -Sigma_p[..., 0, 1],
                               -Sigma_p[..., 1, 0], Sigma_p[..., 0, 0]),
                              dim=-1).reshape(-1, 2, 2)
    Sigma_p_inv = Sigma_p_inv / Sigma_p.det().unsqueeze(-1).unsqueeze(-1)

    dxy = (xy_p - xy_t).unsqueeze(-1)
    xy_distance = 0.5 * dxy.permute(0, 2, 1).bmm(Sigma_p_inv).bmm(
        dxy).view(-1)

    whr_distance = 0.5 * Sigma_p_inv.bmm(
        Sigma_t).diagonal(dim1=-2, dim2=-1).sum(dim=-1)

    Sigma_p_det_log = Sigma_p.det().log()
    Sigma_t_det_log = Sigma_t.det().log()
    whr_distance = whr_distance + 0.5 * (Sigma_p_det_log - Sigma_t_det_log)
    whr_distance = whr_distance - 1
    distance = (xy_distance / (alpha * alpha) + whr_distance)
    if sqrt:
        distance = distance.clamp(0).sqrt()

    distance = distance.reshape(_shape[:-1])
    return postprocess(distance, fun=fun, tau=tau)
def xy_wh_r_2_xy_sigma(xywhr):

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值