GWD loss 损失函数的示例

时间:2023年1月4日

一、目的背景
学习了论文《Rethinking Rotated Object Detection with Gaussian Wasserstein Distance Loss》,想看一下GWD损失函数在代码层面是怎么实现的。于是就在mmrotate代码库中把GWD损失函数代码单独摘了出来,写了一个小demo测一下。

二、代码

# 测试gwd
import torch

def gwd_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, normalize=True):
    xy_p, Sigma_p = pred
    xy_t, Sigma_t = target

    xy_distance = (xy_p - xy_t).square().sum(dim=-1)

    whr_distance = Sigma_p.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    whr_distance = whr_distance + Sigma_t.diagonal(
        dim1=-2, dim2=-1).sum(dim=-1)

    _t_tr = (Sigma_p.bmm(Sigma_t)).diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    _t_det_sqrt = (Sigma_p.det() * Sigma_t.det()).clamp(1e-7).sqrt()
    whr_distance = whr_distance + (-2) * (
        (_t_tr + 2 * _t_det_sqrt).clamp(1e-7).sqrt())

    distance = (xy_distance + alpha * alpha * whr_distance).clamp(1e-7).sqrt()

    if normalize:
        scale = 2 * (
            _t_det_sqrt.clamp(1e-7).sqrt().clamp(1e-7).sqrt()).clamp(1e-7)
        distance = distance / scale

    return postprocess(distance, fun=fun, tau=tau)


def postprocess(distance, fun='log1p', tau=1.0):
    """Convert distance to loss.

    Args:
        distance (torch.Tensor)
        fun (str, optional): The function applied to distance.
            Defaults to 'log1p'.
        tau (float, optional): Defaults to 1.0.

    Returns:
        loss (torch.Tensor)
    """
    if fun == 'log1p':
        distance = torch.log1p(distance)
    elif fun == 'sqrt':
        distance = torch.sqrt(distance.clamp(1e-7))
    elif fun == 'none':
        pass
    else:
        raise ValueError(f'Invalid non-linear function {fun}')

    if tau >= 1.0:
        return 1 - 1 / (tau + distance)
    else:
        return distance


def xy_wh_r_2_xy_sigma(xywhr):
    """Convert oriented bounding box to 2-D Gaussian distribution.

    Args:
        xywhr (torch.Tensor): rbboxes with shape (N, 5).

    Returns:
        xy (torch.Tensor): center point of 2-D Gaussian distribution
            with shape (N, 2).
        sigma (torch.Tensor): covariance matrix of 2-D Gaussian distribution
            with shape (N, 2, 2).
    """
    _shape = xywhr.shape
    assert _shape[-1] == 5
    xy = xywhr[..., :2]
    wh = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
    r = xywhr[..., 4]
    cos_r = torch.cos(r)
    sin_r = torch.sin(r)
    R = torch.stack((cos_r, -sin_r, sin_r, cos_r), dim=-1).reshape(-1, 2, 2)
    S = 0.5 * torch.diag_embed(wh)

    sigma = R.bmm(S.square()).bmm(R.permute(0, 2,
                                            1)).reshape(_shape[:-1] + (2, 2))
    
    return xy, sigma

if __name__ == '__main__':
    print(torch.cuda.is_available())
    obbox1 = torch.tensor([100, 100, 10, 50, 0]).float().reshape(-1, 5).cuda()
    obbox2 = torch.tensor([100, 100, 10, 50, 90]).float().reshape(-1, 5).cuda()
    pred = xy_wh_r_2_xy_sigma(obbox1)
    target = xy_wh_r_2_xy_sigma(obbox2)
    loss = gwd_loss(pred, target)
    print(loss)

三、GWD loss代码在mmrotate中位置

/mmrotate/models/losses/gaussian_dist_loss.py
  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
中软3762.q-gwd集中器下行软件是一款用于控制和管理中软3762.q-gwd集中器的软件。下行软件通过下发指令和配置参数,实现对集中器的远程控制和监测。其主要功能如下: 1. 集中器参数配置:可以通过下行软件对集中器的基本参数进行配置,包括通信参数、时间同步、网络设置等。通过设定合适的参数,可以确保集中器与上级系统之间的通信正常和准确。 2. 负载管理:下行软件可以对集中器的负载进行管理,可以设置负载的分配比例、控制负载的开关状态等。通过合理的负载管理,可以提高集中器的工作效率,确保系统的稳定性和可靠性。 3. 数据采集和传输:下行软件可以实现对集中器的实时数据采集和传输,包括电流、电压、功率因数等相关参数。通过实时监测和传输数据,可以及时发现和解决系统中的问题,提高系统的运行效率。 4. 告警处理:下行软件能够及时接收并处理来自集中器的告警信息,包括电源故障、通信故障、超出设定范围等。可以通过软件设置告警的级别和处理方式,在出现故障或异常时,及时采取相应的措施,保障系统的正常运行。 5. 远程升级:下行软件支持远程对集中器进行固件升级和软件升级,可以实现对集中器的新功能添加和性能优化。通过持续的升级,可以不断提高集中器的性能和功能,满足不同需求。 总之,中软3762.q-gwd集中器下行软件是一款功能强大的软件,能够对集中器进行综合性的控制和管理,提高系统的运行效率和稳定性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值