时间:2023年1月4日
一、目的背景
学习了论文《Rethinking Rotated Object Detection with Gaussian Wasserstein Distance Loss》,想看一下GWD损失函数在代码层面是怎么实现的。于是就在mmrotate代码库中把GWD损失函数代码单独摘了出来,写了一个小demo测一下。
二、代码
# 测试gwd
import torch
def gwd_loss(pred, target, fun='log1p', tau=1.0, alpha=1.0, normalize=True):
    """Gaussian Wasserstein distance loss between two batches of Gaussians.

    Args:
        pred (tuple[torch.Tensor, torch.Tensor]): ``(xy, Sigma)`` of the
            predicted Gaussians. ``Sigma`` must be 3-D because it is passed
            to ``bmm``; presumably ``xy`` is (N, 2) and ``Sigma`` is
            (N, 2, 2) -- confirm against the caller.
        target (tuple[torch.Tensor, torch.Tensor]): ``(xy, Sigma)`` of the
            target Gaussians, laid out like ``pred``.
        fun (str, optional): Non-linear function forwarded to
            ``postprocess``. Defaults to 'log1p'.
        tau (float, optional): Forwarded to ``postprocess``.
            Defaults to 1.0.
        alpha (float, optional): Weight of the covariance term; it enters
            the squared distance as ``alpha * alpha``. Defaults to 1.0.
        normalize (bool, optional): Whether to divide the distance by a
            scale derived from the two determinants. Defaults to True.

    Returns:
        loss (torch.Tensor): The result of ``postprocess`` applied to the
        (optionally normalized) distance.
    """
    xy_p, Sigma_p = pred
    xy_t, Sigma_t = target

    # Squared Euclidean distance between the two centers.
    xy_distance = (xy_p - xy_t).square().sum(dim=-1)

    # Tr(Sigma_p) + Tr(Sigma_t).
    whr_distance = Sigma_p.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    whr_distance = whr_distance + Sigma_t.diagonal(
        dim1=-2, dim2=-1).sum(dim=-1)

    # Cross term built from Tr(Sigma_p @ Sigma_t) and
    # sqrt(det(Sigma_p) * det(Sigma_t)); the clamps keep every sqrt away
    # from zero / negative inputs.
    _t_tr = (Sigma_p.bmm(Sigma_t)).diagonal(dim1=-2, dim2=-1).sum(dim=-1)
    _t_det_sqrt = (Sigma_p.det() * Sigma_t.det()).clamp(1e-7).sqrt()
    whr_distance = whr_distance + (-2) * (
        (_t_tr + 2 * _t_det_sqrt).clamp(1e-7).sqrt())

    distance = (xy_distance + alpha * alpha * whr_distance).clamp(1e-7).sqrt()

    if normalize:
        # Scale is twice the fourth root of det(Sigma_p) * det(Sigma_t),
        # clamped so the division below cannot blow up.
        scale = 2 * (
            _t_det_sqrt.clamp(1e-7).sqrt().clamp(1e-7).sqrt()).clamp(1e-7)
        distance = distance / scale

    return postprocess(distance, fun=fun, tau=tau)
def postprocess(distance, fun='log1p', tau=1.0):
    """Convert distance to loss.

    Args:
        distance (torch.Tensor): Distance values to convert.
        fun (str, optional): The function applied to distance first:
            'log1p', 'sqrt' (input clamped to at least 1e-7) or 'none'.
            Defaults to 'log1p'.
        tau (float, optional): When ``tau >= 1.0`` the transformed distance
            ``d`` is mapped to ``1 - 1 / (tau + d)``; otherwise ``d`` is
            returned as is. Defaults to 1.0.

    Returns:
        loss (torch.Tensor)

    Raises:
        ValueError: If ``fun`` is not one of the supported names.
    """
    if fun == 'log1p':
        distance = torch.log1p(distance)
    elif fun == 'sqrt':
        # Clamp first so sqrt never sees zero or negative values.
        distance = torch.sqrt(distance.clamp(1e-7))
    elif fun == 'none':
        pass
    else:
        raise ValueError(f'Invalid non-linear function {fun}')

    if tau >= 1.0:
        return 1 - 1 / (tau + distance)
    else:
        return distance
def xy_wh_r_2_xy_sigma(xywhr):
    """Convert oriented bounding box to 2-D Gaussian distribution.

    Args:
        xywhr (torch.Tensor): rbboxes with shape (..., 5), typically
            (N, 5), laid out as (x, y, w, h, angle). The angle is fed
            directly to ``torch.cos`` / ``torch.sin``, i.e. it is in
            radians.

    Returns:
        xy (torch.Tensor): center point of 2-D Gaussian distribution
            with shape (..., 2).
        sigma (torch.Tensor): covariance matrix of 2-D Gaussian distribution
            with shape (..., 2, 2).
    """
    _shape = xywhr.shape
    assert _shape[-1] == 5
    xy = xywhr[..., :2]
    # Clamp the sizes so degenerate boxes still give a usable covariance.
    wh = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
    r = xywhr[..., 4]
    cos_r = torch.cos(r)
    sin_r = torch.sin(r)
    # Rotation matrix [[cos, -sin], [sin, cos]] for every box.
    R = torch.stack((cos_r, -sin_r, sin_r, cos_r), dim=-1).reshape(-1, 2, 2)
    # Half extents on the diagonal; squared below to form the covariance.
    S = 0.5 * torch.diag_embed(wh)
    # sigma = R @ S^2 @ R^T, reshaped back to the input's leading dims.
    sigma = R.bmm(S.square()).bmm(R.permute(0, 2,
                                            1)).reshape(_shape[:-1] + (2, 2))
    return xy, sigma
if __name__ == '__main__':
    import math

    cuda_available = torch.cuda.is_available()
    print(cuda_available)
    # Fall back to the CPU so the demo also runs on machines without CUDA
    # instead of crashing on an unconditional .cuda() call.
    device = torch.device('cuda' if cuda_available else 'cpu')

    # Boxes are (x, y, w, h, angle). The angle goes straight into
    # torch.cos / torch.sin, which take radians, so a quarter turn must be
    # written as pi / 2 rather than 90.
    obbox1 = torch.tensor([100, 100, 10, 50, 0]).float().reshape(-1, 5)
    obbox2 = torch.tensor([100, 100, 10, 50,
                           math.pi / 2]).float().reshape(-1, 5)
    obbox1 = obbox1.to(device)
    obbox2 = obbox2.to(device)

    pred = xy_wh_r_2_xy_sigma(obbox1)
    target = xy_wh_r_2_xy_sigma(obbox2)
    loss = gwd_loss(pred, target)
    print(loss)
三、GWD loss代码在mmrotate中的位置
/mmrotate/models/losses/gaussian_dist_loss.py