pytorch 代码
首先上代码,这份实现非常简洁优雅!
class TripletLoss(nn.Module):
"""Triplet loss with hard positive/negative mining.
Reference:
Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
Imported from `<https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py>`_.
Args:
margin (float, optional): margin for triplet. Default is 0.3.
"""
def __init__(self, margin=0.3,global_feat, labels):
super(TripletLoss, self).__init__()
self.margin = margin
# https://pytorch.org/docs/1.2.0/nn.html?highlight=marginrankingloss#torch.nn.MarginRankingLoss
# 计算两个张量之间的相似度,两张量之间的距离>margin,loss 为正,否则loss 为 0
self.ranking_loss = nn.MarginRankingLoss(margin=margin)
def forward(self, inputs, targets):
"""
Args:
inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).
targets (torch.LongTensor): ground truth labels with shape (num_classes).
"""
n = inputs.size(0) # batch_size
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
dist = dist + dist.t()
dist.addmm_(1, -2, inputs, inputs.t())
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
# For each anchor, find the hardest positive and negative
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
dist_ap, dist_an = [], []
for i in range(n):
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
dist_ap = torch.cat(dist_ap)
dist_an = torch.cat(dist_an)
# Compute ranking hinge loss
y = torch.ones_like(dist_an)
loss = self.ranking_loss(dist_an, dist_ap, y)
return loss
代码阅读
样本距离计算
看不懂没关系,先举个简单的例子:
假设(batch_size, feat_dim)为(3,4),再简单点,令inputs
为:
[
1
2
3
4
5
6
7
8
9
10
11
12
]
\begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8\\ 9 & 10 & 11 & 12 \end{bmatrix}
⎣⎡159261037114812⎦⎤
那么第一行
(
1
,
2
,
3
,
4
)
(1, 2 ,3 ,4)
(1,2,3,4)就代表第一个样本(S1),第二行
(
5
,
6
,
7
,
8
)
(5, 6, 7, 8)
(5,6,7,8)代表第二个样本(S2),以此类推。看一下样本之间的距离(欧氏距离)分布,为方便,用<>
表示距离:
<S1,S1>
: ( S 1 − S 1 ) 2 (S1- S1)^2 (S1−S1)2,显然为0<S1,S2>
: ( S 1 − S 2 ) 2 (S1- S2)^2 (S1−S2)2, ( ( 1 − 5 ) 2 + ( 2 − 6 ) 2 + ( 3 − 7 ) 2 + ( 4 − 8 ) 2 ) = 64 ((1-5)^2+(2-6)^2+(3-7)^2+(4-8)^2) = 64 ((1−5)2+(2−6)2+(3−7)2+(4−8)2)=64<S1,S3>
: ( S 1 − S 3 ) 2 (S1- S3)^2 (S1−S3)2, 256 256 256- …
距离分布(开方之后)可以表示为矩阵
D
D
D:
[
0
8
16
8
0
8
16
8
0
]
\begin{bmatrix} 0 & 8 & 16 \\ 8 & 0 &8 \\ 16 & 8 & 0 \end{bmatrix}
⎣⎡08168081680⎦⎤
其中
D
(
i
,
j
)
D(i,j)
D(i,j)表示样本
i
,
j
i,j
i,j之间的距离。可以看到样本之间的距离计算看起来很简单,但是当特征维度比较大的时候,上面的这种最直接的计算方式就会效率低下(我猜的),又因为输入都是张量(矩阵)的形式,所以使用矩阵运算更为合理,那就来看看代码是的实现原理(Triplet-Loss原理及其实现、应用):
因为 ( a − b ) 2 = a 2 − 2 a b + b 2 (a−b)^2=a^2−2ab+b^2 (a−b)2=a2−2ab+b2, 而矩阵相乘 e m b e d d i n g s × e m b e d d i n g s . T embeddings×embeddings.T embeddings×embeddings.T中不仅包含了 a ∗ b a*b a∗b的值,同时对角线上是向量平方的值,所以可以直接使用矩阵计算。
首先输入和例子一样:
inputs = torch.arange(1,13).view(3,4).float()
>>> inputs
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
假设上述3个样本可以表示为 ( S 1 S 2 S 3 ) \bigl( \begin{matrix} S1 \\ S2 \\ S3 \end{matrix} \bigr) (S1S2S3), S 1 即 为 S1即为 S1即为(1, 2, 3, 4), S 1 2 S1^2 S12则是 ( 1 2 + 2 2 + 3 2 + 4 2 ) = 30 (1^2+2^2+3^2+4^2) = 30 (12+22+32+42)=30, S 2 , S 3 S2, S3 S2,S3同理。
n = inputs.size(0) # n = 3,为batch_size
# 每个数平方后, 进行sum(保持行数n不变),再扩展为(n,n)
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
>>> dist
tensor([[ 30., 30., 30.],
[174., 174., 174.],
[446., 446., 446.]])
那么上述的操作结果其实就是:
[
S
1
2
S
1
2
S
1
2
S
2
2
S
2
2
S
2
2
S
3
2
S
3
2
S
3
2
]
\begin{bmatrix} S1^2 & S1^2 & S1^2 \\ S2^2 & S2^2 & S2^2 \\ S3^2 & S3^2 & S3^2 \end{bmatrix}
⎣⎡S12S22S32S12S22S32S12S22S32⎦⎤
# 这样每个dis[i][j]代表的是样本i与样本j的平方的和
dist = dist + dist.t()
>>> dist
tensor([[ 60., 204., 476.],
[204., 348., 620.],
[476., 620., 892.]])
同理,上述操作后,结果为(注意对角线):
[
S
1
2
+
S
1
2
S
1
2
+
S
2
2
S
1
2
+
S
3
2
S
2
2
+
S
1
2
S
2
2
+
S
2
2
S
2
2
+
S
3
2
S
3
2
+
S
1
2
S
3
2
+
S
2
2
S
3
2
+
S
3
2
]
\begin{bmatrix} S1^2+ S1^2 & S1^2+S2^2 & S1^2+S3^2 \\ S2^2+ S1^2 & S2^2+S2^2 & S2^2+S3^2 \\ S3^2+ S1^2 & S3^2+S2^2 & S3^2+S3^2 \end{bmatrix}
⎣⎡S12+S12S22+S12S32+S12S12+S22S22+S22S32+S22S12+S32S22+S32S32+S32⎦⎤
addmm_()的用法:torch — PyTorch master documentation
其实很简单,在这儿就是:
1
⋅
d
i
s
t
−
2
(
i
n
p
u
t
@
i
n
p
u
t
.
t
(
)
)
(1)
1·dist - 2(input @ input.t()) \tag{1}
1⋅dist−2(input@input.t())(1)
dist.addmm_(1, -2, inputs, inputs.t())
>>> dist
tensor([[ 0., 64., 256.],
[ 64., 0., 64.],
[256., 64., 0.]])
2
∗
(
i
n
p
u
t
@
i
n
p
u
t
.
t
(
)
)
2*(input @ input.t())
2∗(input@input.t())是这样的:
[
2
S
1
2
2
S
1
∗
S
2
2
S
1
∗
S
3
2
S
2
∗
S
1
2
S
2
2
2
S
2
∗
S
3
2
S
3
∗
S
1
2
S
3
∗
S
2
2
S
3
2
]
\begin{bmatrix} 2S1^2 & 2S1*S2 & 2S1*S3 \\ 2S2*S1 &2 S2^2 & 2S2*S3 \\ 2S3*S1 & 2S3*S2 & 2S3^2 \end{bmatrix}
⎣⎡2S122S2∗S12S3∗S12S1∗S22S222S3∗S22S1∗S32S2∗S32S32⎦⎤
是不是看起来有完全平方式那个味了,没错,代入(1)式之后的表达是这样的:
[
0
S
1
2
+
S
2
2
−
2
S
1
∗
S
2
S
1
2
+
S
3
2
−
2
S
1
∗
S
3
S
2
2
+
S
1
2
−
2
S
2
∗
S
1
0
S
2
2
+
S
3
2
−
2
S
2
∗
S
3
S
3
2
+
S
1
2
−
2
S
3
∗
S
1
S
3
2
+
S
2
2
−
2
S
3
∗
S
2
0
]
\begin{bmatrix} 0 & S1^2+S2^2- 2S1*S2& S1^2+S3^2-2S1*S3 \\ S2^2+ S1^2-2S2*S1 & 0 & S2^2+S3^2-2S2*S3 \\ S3^2+ S1^2-2S3*S1 & S3^2+S2^2-2S3*S2 & 0 \end{bmatrix}
⎣⎡0S22+S12−2S2∗S1S32+S12−2S3∗S1S12+S22−2S1∗S20S32+S22−2S3∗S2S12+S32−2S1∗S3S22+S32−2S2∗S30⎦⎤
也就是:
[
(
S
1
−
S
1
)
2
(
S
1
−
S
2
)
2
(
S
1
−
S
3
)
2
(
S
2
−
S
1
)
2
(
S
2
−
S
2
)
2
(
S
2
−
S
3
)
2
(
S
3
−
S
1
)
2
(
S
3
−
S
2
)
2
(
S
3
−
S
3
)
2
]
\begin{bmatrix} (S1- S1)^2 & (S1- S2)^2 & (S1- S3)^2 \\ (S2- S1)^2 & (S2- S2)^2 & (S2- S3)^2 \\ (S3- S1)^2 & (S3- S2)^2 & (S3- S3)^2 \end{bmatrix}
⎣⎡(S1−S1)2(S2−S1)2(S3−S1)2(S1−S2)2(S2−S2)2(S3−S2)2(S1−S3)2(S2−S3)2(S3−S3)2⎦⎤
下面进行开方,clamp
做简单数值处理(为了数值稳定性):小于min
参数的dist
元素值由min
值取代。
Triplet-Loss原理及其实现、应用
根号下不能为0,0开根号没有问题的,但是梯度反向传播就会导致无穷大
dist = dist.clamp(min=1e-12).sqrt()
>>> dist
tensor([[1.0000e-06, 8.0000e+00, 1.6000e+01],
[8.0000e+00, 1.0000e-06, 8.0000e+00],
[1.6000e+01, 8.0000e+00, 1.0000e-06]])
那么到这里,样本对之间距离计算就到位了,下面还要进行的是困难样本挖掘。
困难样本挖掘
For each anchor, find the hardest positive and negative
上一节操作输出的张量dist
里面存储着各个样本之间的距离,首先看看targets:
targets: ground truth labels with shape (num_classes)
其实就是样本对应的标签
首先获取mask,类似掩模,# 这里
m
a
s
k
[
i
]
[
j
]
=
1
mask[i][j] = 1
mask[i][j]=1 代表
i
i
i 和
j
j
j 的label
相同(属于同一类别),
m
a
s
k
[
i
]
[
j
]
=
0
mask[i][j] = 0
mask[i][j]=0则相反。mask用于后面提取正样本和负样本。
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
下面分别提取出正样本和负样本,对每个样本,在上面生成的距离矩阵中:
- 先过滤掉和它不同类别的样本对应的距离,剩下的就是和它同一类别的
positive
,然后再在剩下的positive
中找到距离值最大的,就是我们需要的hard positive
- 寻找
negative
也是同理
for i in range(n):
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
拼接为新的tensor,并计算相应的损失,这儿没有难度。至于为什么要这么算,看一下 loss说明就知道了。
dist_ap = torch.cat(dist_ap)
dist_an = torch.cat(dist_an)
y = torch.ones_like(dist_an)
loss = self.ranking_loss(dist_an, dist_ap, y)
自己去弄个简单例子一步一步跑一边就明白大概了。。。
完