# 先知道,左边点云和右边点云哪些点是有必要统计损失的,
# 换句话说,拿到所有匹配对中的点,不在匹配对中的点,不算损失。
# 因为有的点 压根不在匹配对里,那这种点的特征就没必要算损失了。
row_masks = (torch.gt(pos_masks.sum(-1), 0) & torch.gt(neg_masks.sum(-1), 0)).detach()
col_masks = (torch.gt(pos_masks.sum(-2), 0) & torch.gt(neg_masks.sum(-2), 0)).detach()
# 这里是计算妙用,没什么高级的操作
# 目的就是,为所有GT的匹配对,构建权重,权重是(特征相似度 - 0.1)
pos_weights = feat_dists - 1e5 * (~pos_masks).float() # mask the non-positive
pos_weights = pos_weights - pos_optimal # mask the uninformative positive
pos_weights = torch.maximum(torch.zeros_like(pos_weights), pos_weights)
# 把overlap的信息利用进来,
# 如果你这一对,overlap很大很大,那必须给你一个大的权重,因为你很重要,importance很大
if pos_scales is not None:
pos_weights = pos_weights * pos_scales
pos_weights = pos_weights.detach()
# 为所有非GT的匹配对,构建权重! 注意!
# 这种非GT的匹配对显然是很多很多的。
# 但是有的,一整行其实都不会参与最后的损失计算,例如第156行,如果第156个点,不是在匹配对里,那这一行的数值都没什么用了。
neg_weights = feat_dists + 1e5 * (~neg_masks).float() # mask the non-negative
neg_weights = neg_optimal - neg_weights # mask the uninformative negative
neg_weights = torch.maximum(torch.zeros_like(neg_weights), neg_weights)
# 这里跳过
if neg_scales is not None:
neg_weights = neg_weights * neg_scales
neg_weights = neg_weights.detach()
# log_scale 是常数,(feat_dists - pos_margin) 是优化目标,即希望feats_dist接近pos_margin
# 别忘了我们的 pos_weights, 这是经由 overlap的大小+ 特征相似性得到的。
loss_pos_row = torch.logsumexp(log_scale * (feat_dists - pos_margin) * pos_weights, dim=-1)
loss_pos_col = torch.logsumexp(log_scale * (feat_dists - pos_margin) * pos_weights, dim=-2)
loss_neg_row = torch.logsumexp(log_scale * (neg_margin - feat_dists) * neg_weights, dim=-1)
loss_neg_col = torch.logsumexp(log_scale * (neg_margin - feat_dists) * neg_weights, dim=-2)
loss_row = F.softplus(loss_pos_row + loss_neg_row) / log_scale
loss_col = F.softplus(loss_pos_col + loss_neg_col) / log_scale
# 我们只取出 在匹配对中的行和列。
loss = (loss_row[row_masks].mean() + loss_col[col_masks].mean()) / 2
return loss
GeoTransformer的Coarse matching loss
最新推荐文章于 2024-04-26 09:36:52 发布