PyTorch triphard代码理解

pytorch 代码

首先上代码,这份实现非常简洁优雅!

class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining.
    
    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    
    Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.
    
    Args:
        margin (float, optional): margin for triplet. Default is 0.3.
    """
    
    def __init__(self, margin=0.3,global_feat, labels):
        super(TripletLoss, self).__init__()
        self.margin = margin
        # https://pytorch.org/docs/1.2.0/nn.html?highlight=marginrankingloss#torch.nn.MarginRankingLoss
        # 计算两个张量之间的相似度,两张量之间的距离>margin,loss 为正,否则loss 为 0
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)
 
    def forward(self, inputs, targets):
        """
        Args:
            inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
            targets (torch.LongTensor): ground truth labels with shape (num_classes).
        """
        n = inputs.size(0)	# batch_size
        
        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs, inputs.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        
        # For each anchor, find the hardest positive and negative
        mask = targets.expand(n, n).eq(targets.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
            dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
        dist_ap = torch.cat(dist_ap)
        dist_an = torch.cat(dist_an)
        
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss

代码阅读

样本距离计算

看不懂没关系,先举个简单的例子:
假设(batch_size, feat_dim)为(3,4),再简单点,令inputs为:
[ 1 2 3 4 5 6 7 8 9 10 11 12 ] \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8\\ 9 & 10 & 11 & 12 \end{bmatrix} 159261037114812
那么第一行 ( 1 , 2 , 3 , 4 ) (1, 2 ,3 ,4) (1,2,3,4)就代表第一个样本(S1),第二行 ( 5 , 6 , 7 , 8 ) (5, 6, 7, 8) (5,6,7,8)代表第二个样本(S2),以此类推。看一下样本之间的距离(欧氏距离)分布,为方便,用<>表示距离:

  • <S1,S1> ( S 1 − S 1 ) 2 (S1- S1)^2 (S1S1)2,显然为0
  • <S1,S2> ( S 1 − S 2 ) 2 (S1- S2)^2 (S1S2)2 ( ( 1 − 5 ) 2 + ( 2 − 6 ) 2 + ( 3 − 7 ) 2 + ( 4 − 8 ) 2 ) = 64 ((1-5)^2+(2-6)^2+(3-7)^2+(4-8)^2) = 64 ((15)2+(26)2+(37)2+(48)2)=64
  • <S1,S3> ( S 1 − S 3 ) 2 (S1- S3)^2 (S1S3)2 256 256 256

距离分布(开方之后)可以表示为矩阵 D D D
[ 0 8 16 8 0 8 16 8 0 ] \begin{bmatrix} 0 & 8 & 16 \\ 8 & 0 &8 \\ 16 & 8 & 0 \end{bmatrix} 08168081680
其中 D ( i , j ) D(i,j) D(i,j)表示样本 i , j i,j i,j之间的距离。可以看到样本之间的距离计算看起来很简单,但是当特征维度比较大的时候,上面的这种最直接的计算方式就会效率低下(我猜的),又因为输入都是张量(矩阵)的形式,所以使用矩阵运算更为合理,那就来看看代码是的实现原理(Triplet-Loss原理及其实现、应用):

因为 ( a − b ) 2 = a 2 − 2 a b + b 2 (a−b)^2=a^2−2ab+b^2 (ab)2=a22ab+b2, 而矩阵相乘 e m b e d d i n g s × e m b e d d i n g s . T embeddings×embeddings.T embeddings×embeddings.T中不仅包含了 a ∗ b a*b ab的值,同时对角线上是向量平方的值,所以可以直接使用矩阵计算。

首先输入和例子一样:

inputs = torch.arange(1,13).view(3,4).float()
>>> inputs
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])

假设上述3个样本可以表示为 ( S 1 S 2 S 3 ) \bigl( \begin{matrix} S1 \\ S2 \\ S3 \end{matrix} \bigr) (S1S2S3) S 1 即 为 S1即为 S1(1, 2, 3, 4), S 1 2 S1^2 S12则是 ( 1 2 + 2 2 + 3 2 + 4 2 ) = 30 (1^2+2^2+3^2+4^2) = 30 (12+22+32+42)=30 S 2 , S 3 S2, S3 S2,S3同理。

n = inputs.size(0)	# n = 3,为batch_size
# 每个数平方后, 进行sum(保持行数n不变),再扩展为(n,n)
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
>>> dist
tensor([[ 30.,  30.,  30.],
        [174., 174., 174.],
        [446., 446., 446.]])

那么上述的操作结果其实就是:
[ S 1 2 S 1 2 S 1 2 S 2 2 S 2 2 S 2 2 S 3 2 S 3 2 S 3 2 ] \begin{bmatrix} S1^2 & S1^2 & S1^2 \\ S2^2 & S2^2 & S2^2 \\ S3^2 & S3^2 & S3^2 \end{bmatrix} S12S22S32S12S22S32S12S22S32

# 这样每个dis[i][j]代表的是样本i与样本j的平方的和
dist = dist + dist.t()
>>> dist
tensor([[ 60., 204., 476.],
        [204., 348., 620.],
        [476., 620., 892.]])

同理,上述操作后,结果为(注意对角线):
[ S 1 2 + S 1 2 S 1 2 + S 2 2 S 1 2 + S 3 2 S 2 2 + S 1 2 S 2 2 + S 2 2 S 2 2 + S 3 2 S 3 2 + S 1 2 S 3 2 + S 2 2 S 3 2 + S 3 2 ] \begin{bmatrix} S1^2+ S1^2 & S1^2+S2^2 & S1^2+S3^2 \\ S2^2+ S1^2 & S2^2+S2^2 & S2^2+S3^2 \\ S3^2+ S1^2 & S3^2+S2^2 & S3^2+S3^2 \end{bmatrix} S12+S12S22+S12S32+S12S12+S22S22+S22S32+S22S12+S32S22+S32S32+S32

addmm_()的用法:torch — PyTorch master documentation
其实很简单,在这儿就是:
1 ⋅ d i s t − 2 ( i n p u t @ i n p u t . t ( ) ) (1) 1·dist - 2(input @ input.t()) \tag{1} 1dist2(input@input.t())(1)

dist.addmm_(1, -2, inputs, inputs.t())
>>> dist
tensor([[  0.,  64., 256.],
        [ 64.,   0.,  64.],
        [256.,  64.,   0.]])

2 ∗ ( i n p u t @ i n p u t . t ( ) ) 2*(input @ input.t()) 2(input@input.t())是这样的:
[ 2 S 1 2 2 S 1 ∗ S 2 2 S 1 ∗ S 3 2 S 2 ∗ S 1 2 S 2 2 2 S 2 ∗ S 3 2 S 3 ∗ S 1 2 S 3 ∗ S 2 2 S 3 2 ] \begin{bmatrix} 2S1^2 & 2S1*S2 & 2S1*S3 \\ 2S2*S1 &2 S2^2 & 2S2*S3 \\ 2S3*S1 & 2S3*S2 & 2S3^2 \end{bmatrix} 2S122S2S12S3S12S1S22S222S3S22S1S32S2S32S32
是不是看起来有完全平方式那个味了,没错,代入(1)式之后的表达是这样的:
[ 0 S 1 2 + S 2 2 − 2 S 1 ∗ S 2 S 1 2 + S 3 2 − 2 S 1 ∗ S 3 S 2 2 + S 1 2 − 2 S 2 ∗ S 1 0 S 2 2 + S 3 2 − 2 S 2 ∗ S 3 S 3 2 + S 1 2 − 2 S 3 ∗ S 1 S 3 2 + S 2 2 − 2 S 3 ∗ S 2 0 ] \begin{bmatrix} 0 & S1^2+S2^2- 2S1*S2& S1^2+S3^2-2S1*S3 \\ S2^2+ S1^2-2S2*S1 & 0 & S2^2+S3^2-2S2*S3 \\ S3^2+ S1^2-2S3*S1 & S3^2+S2^2-2S3*S2 & 0 \end{bmatrix} 0S22+S122S2S1S32+S122S3S1S12+S222S1S20S32+S222S3S2S12+S322S1S3S22+S322S2S30
也就是:
[ ( S 1 − S 1 ) 2 ( S 1 − S 2 ) 2 ( S 1 − S 3 ) 2 ( S 2 − S 1 ) 2 ( S 2 − S 2 ) 2 ( S 2 − S 3 ) 2 ( S 3 − S 1 ) 2 ( S 3 − S 2 ) 2 ( S 3 − S 3 ) 2 ] \begin{bmatrix} (S1- S1)^2 & (S1- S2)^2 & (S1- S3)^2 \\ (S2- S1)^2 & (S2- S2)^2 & (S2- S3)^2 \\ (S3- S1)^2 & (S3- S2)^2 & (S3- S3)^2 \end{bmatrix} (S1S1)2(S2S1)2(S3S1)2(S1S2)2(S2S2)2(S3S2)2(S1S3)2(S2S3)2(S3S3)2

下面进行开方,clamp做简单数值处理(为了数值稳定性):小于min参数的dist元素值由min值取代。

Triplet-Loss原理及其实现、应用
根号下不能为0,0开根号没有问题的,但是梯度反向传播就会导致无穷大

dist = dist.clamp(min=1e-12).sqrt()
>>> dist
tensor([[1.0000e-06, 8.0000e+00, 1.6000e+01],
        [8.0000e+00, 1.0000e-06, 8.0000e+00],
        [1.6000e+01, 8.0000e+00, 1.0000e-06]])

那么到这里,样本对之间距离计算就到位了,下面还要进行的是困难样本挖掘。

困难样本挖掘

For each anchor, find the hardest positive and negative
上一节操作输出的张量dist里面存储着各个样本之间的距离,首先看看targets:

targets: ground truth labels with shape (num_classes)
其实就是样本对应的标签

首先获取mask,类似掩模,# 这里 m a s k [ i ] [ j ] = 1 mask[i][j] = 1 mask[i][j]=1 代表 i i i j j jlabel相同(属于同一类别), m a s k [ i ] [ j ] = 0 mask[i][j] = 0 mask[i][j]=0则相反。mask用于后面提取正样本和负样本。

mask = targets.expand(n, n).eq(targets.expand(n, n).t())

下面分别提取出正样本和负样本,对每个样本,在上面生成的距离矩阵中:

  • 先过滤掉和它不同类别的样本对应的距离,剩下的就是和它同一类别的positive,然后再在剩下的positive中找到距离值最大的,就是我们需要的hard positive
  • 寻找negative也是同理
for i in range(n):
	dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
    dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))

拼接为新的tensor,并计算相应的损失,这儿没有难度。至于为什么要这么算,看一下 loss说明就知道了。

dist_ap = torch.cat(dist_ap)
dist_an = torch.cat(dist_an)

y = torch.ones_like(dist_an) 
loss = self.ranking_loss(dist_an, dist_ap, y)

自己去弄个简单例子一步一步跑一边就明白大概了。。。

  • 35
    点赞
  • 70
    收藏
    觉得还不错? 一键收藏
  • 18
    评论
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值