Triplet loss(三元损失函数)

Triplet loss(三元损失函数)

Triplet Loss是Google在2015年发表的FaceNet论文中提出的,论文原文见附录。Triplet Loss即三元组损失,我们详细来介绍一下。

Triplet Loss定义:最小化锚点和具有相同身份的正样本之间的距离,最小化锚点和具有不同身份的负样本之间的距离。
Triplet Loss的目标:Triplet Loss的目标是使得相同标签的特征在空间位置上尽量靠近,同时不同标签的特征在空间位置上尽量远离,同时为了不让样本的特征聚合到一个非常小的空间中要求对于同一类的两个正例和一个负例,负例应该比正例的距离至少远margin。如下图所示:

可以看出经过Triplet loss学习以后同类的Positive样本和Anchor的距离越来越近而不同类的Negative样本和Anchor的距离越来越远。

在这里插入图片描述
因为我们期望下式成立:
在这里插入图片描述
其中α就是上面提到的margin,就是样本容量为N的数据集的各种三元组。然后根据上式,Triplet Loss可以写成:
在这里插入图片描述
注意,上面式子右下角的那个+的含义是[]内的值大于0的时候,取该值为损失,小于0的时候,损失为0.

对应的针对三个样本的梯度计算公式为:
∂ L ∂ f ( x i a ) = 2 ∗ ( f ( x i a ) − f ( x i p ) ) − 2 ∗ ( f ( x i a ) − f ( x i n ) ) = 2 ∗ ( f ( x i n ) − f ( x i p ) ) ∂ L ∂ f ( x i p ) = 2 ∗ ( f ( x i a ) − f ( x i p ) ) ∗ ( − 1 ) = 2 ∗ ( f ( x i p ) − f ( x i a ) ) ∂ L ∂ f ( x i n ) = 2 ∗ ( f ( x i a ) − f ( x i n ) ) ∗ ( − 1 ) = 2 ∗ ( f ( x i a ) − f ( x i p ) ) \begin{aligned} \frac{\partial L}{\partial f\left(x_{i}^{a}\right)} &=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right)-2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right)=2 *\left(f\left(x_{i}^{n}\right)-f\left(x_{i}^{p}\right)\right) \\ \frac{\partial L}{\partial f\left(x_{i}^{p}\right)} &=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right) *(-1)=2 *\left(f\left(x_{i}^{p}\right)-f\left(x_{i}^{a}\right)\right) \\ \frac{\partial L}{\partial f\left(x_{i}^{n}\right)} &=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right) *(-1)=2 *\left(f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right) \end{aligned} f(xia)Lf(xip)Lf(xin)L=2(f(xia)f(xip))2(f(xia)f(xin))=2(f(xin)f(xip))=2(f(xia)f(xip))(1)=2(f(xip)f(xia))=2(f(xia)f(xin))(1)=2(f(xia)f(xip))

代码实现:

def triplet_loss(y_true, y_pred):
        """
        Triplet Loss的损失函数
        """

        anc, pos, neg = y_pred[:, 0:128], y_pred[:, 128:256], y_pred[:, 256:]

        # 欧式距离
        pos_dist = K.sum(K.square(anc - pos), axis=-1, keepdims=True)
        neg_dist = K.sum(K.square(anc - neg), axis=-1, keepdims=True)
        basic_loss = pos_dist - neg_dist + TripletModel.MARGIN

        loss = K.maximum(basic_loss, 0.0)

        print "[INFO] model - triplet_loss shape: %s" % str(loss.shape)
        return loss
  • 10
    点赞
  • 71
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值