Triplet loss(三元损失函数)
Triplet Loss是Google在2015年发表的FaceNet论文中提出的,论文原文见附录。Triplet Loss即三元组损失,我们详细来介绍一下。
Triplet Loss定义:最小化锚点和具有相同身份的正样本之间的距离,最小化锚点和具有不同身份的负样本之间的距离。
Triplet Loss的目标:Triplet Loss的目标是使得相同标签的特征在空间位置上尽量靠近,同时不同标签的特征在空间位置上尽量远离,同时为了不让样本的特征聚合到一个非常小的空间中要求对于同一类的两个正例和一个负例,负例应该比正例的距离至少远margin。如下图所示:
可以看出经过Triplet loss学习以后同类的Positive样本和Anchor的距离越来越近而不同类的Negative样本和Anchor的距离越来越远。
因为我们期望下式成立:
其中α就是上面提到的margin,就是样本容量为N的数据集的各种三元组。然后根据上式,Triplet Loss可以写成:
注意,上面式子右下角的那个+的含义是[]内的值大于0的时候,取该值为损失,小于0的时候,损失为0.
对应的针对三个样本的梯度计算公式为:
∂
L
∂
f
(
x
i
a
)
=
2
∗
(
f
(
x
i
a
)
−
f
(
x
i
p
)
)
−
2
∗
(
f
(
x
i
a
)
−
f
(
x
i
n
)
)
=
2
∗
(
f
(
x
i
n
)
−
f
(
x
i
p
)
)
∂
L
∂
f
(
x
i
p
)
=
2
∗
(
f
(
x
i
a
)
−
f
(
x
i
p
)
)
∗
(
−
1
)
=
2
∗
(
f
(
x
i
p
)
−
f
(
x
i
a
)
)
∂
L
∂
f
(
x
i
n
)
=
2
∗
(
f
(
x
i
a
)
−
f
(
x
i
n
)
)
∗
(
−
1
)
=
2
∗
(
f
(
x
i
a
)
−
f
(
x
i
p
)
)
\begin{aligned} \frac{\partial L}{\partial f\left(x_{i}^{a}\right)} &=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right)-2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right)=2 *\left(f\left(x_{i}^{n}\right)-f\left(x_{i}^{p}\right)\right) \\ \frac{\partial L}{\partial f\left(x_{i}^{p}\right)} &=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right) *(-1)=2 *\left(f\left(x_{i}^{p}\right)-f\left(x_{i}^{a}\right)\right) \\ \frac{\partial L}{\partial f\left(x_{i}^{n}\right)} &=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right) *(-1)=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right) \end{aligned}
∂f(xia)∂L∂f(xip)∂L∂f(xin)∂L=2∗(f(xia)−f(xip))−2∗(f(xia)−f(xin))=2∗(f(xin)−f(xip))=2∗(f(xia)−f(xip))∗(−1)=2∗(f(xip)−f(xia))=2∗(f(xia)−f(xin))∗(−1)=2∗(f(xia)−f(xip))
代码实现:
def triplet_loss(y_true, y_pred):
"""
Triplet Loss的损失函数
"""
anc, pos, neg = y_pred[:, 0:128], y_pred[:, 128:256], y_pred[:, 256:]
# 欧式距离
pos_dist = K.sum(K.square(anc - pos), axis=-1, keepdims=True)
neg_dist = K.sum(K.square(anc - neg), axis=-1, keepdims=True)
basic_loss = pos_dist - neg_dist + TripletModel.MARGIN
loss = K.maximum(basic_loss, 0.0)
print "[INFO] model - triplet_loss shape: %s" % str(loss.shape)
return loss