loss.py(三元组损失)

# -*- encoding: utf-8 -*-
"""
@File    : losses.py
@Time    : 2021-05-13 17:35
@Author  : XD
@Email   : gudianpai@qq.com
@Software: PyCharm
"""
from IPython import embed
import torch.nn as nn
import torch

class TripleLoss(nn.Module):
    def __init__(self, margin = 0.3):
        super(TripleLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin = margin)
    def forward(self,inputs, target):
        """
        Args:
        inputs: feature matrix with shape (batch_size, feat_dim)
        targets: ground truth labels with shape (num_classes)
        """
        n = inputs.size(0)
        #(a - b)^2 = a^2 + b^2 -2ab
        #pytorch和numpy都是点乘 [1,2,3] * [1,2,3] =[1,4,9]
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()

        #dist.addmm(1, -2, inputs, inputs.t())
        dist = dist - 2 *inputs@inputs.t()
        dist = dist.clamp(min = 1e-12).sqrt()  # for numerical stability
        # For each anchor, find the hardest positive and negative

        #mask = target.extend(n,n) == target.extend(n,n).t
        # 这句话可能会使得返回一个布尔类型的张量,但是我们需要一个..tensor

        mask = target.expand(n, n).eq(target.expand(n,n).t())
        dist_ap ,dist_an = [],[]
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))#dist[i][mask[i]].max() 类型为tensor
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)

        y = torch.zeros_like(dist_ap) #为什么需要这一行代码呢 大大的疑问
        #loss = self.ranking_loss(dist_ap, dist_an, y) 罗博士第一次写错了
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss
if __name__ == '__main__':
    target = [1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,6,6,6,6,7,7,7,7,8,8,8,8,]
    target = torch.tensor(target)
    feature = torch.eye(32,2048)
    print("feature:",feature)
    a = TripleLoss()
    print(a.forward(feature, target))




  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值