Improved Deep Metric Learning with Multi-class N-pair Loss Objective: paper notes and an implementation of the N-pair loss

Paper: NIPS 2016, Improved Deep Metric Learning with Multi-class N-pair Loss Objective

Distance metric learning aims to learn an embedding of the data under which similar points stay close together and dissimilar points stay far apart. Thanks to the rapid progress of deep learning, deep metric learning has attracted wide attention. Compared with standard distance metric learning, a deep network can learn a nonlinear embedding, and such embeddings, combined with the contrastive loss and the triplet loss, have been hugely successful in face recognition and image retrieval. Despite this success, these frameworks often suffer from slow convergence or get stuck in poor local optima, because each update of the network weights considers only a single negative example and ignores the distances to negatives from all other classes. Hard negative mining can alleviate the problem, but it imposes a heavy computational burden in a deep learning setting.

To address this, the paper proposes an (N+1)-tuplet loss that jointly optimizes against one positive example and N-1 negative examples. When N = 2 it is equivalent to the triplet loss; when N is large, the loss itself becomes very expensive to compute. To solve that, the paper introduces an efficient batch construction strategy that covers N classes with only 2N examples instead of N(N+1). The paper calls the combination of the (N+1)-tuplet loss and this batch construction the multi-class N-pair loss (N-pair-mc loss). When the number of classes N is small or moderate, the N-pair loss already optimizes the distances to N-1 negatives together with one positive, so hard negative mining is unnecessary; when N is very large it is still needed, and the paper also proposes a hard negative class mining technique (which I have not looked into).

Recap
First, recall the contrastive loss and the triplet loss.
As defined in the paper (m is a margin and $f_i = f(x_i)$ denotes the embedding of $x_i$):

$$\mathcal{L}^{m}_{\text{cont}}(x_i, x_j; f) = \mathbb{1}\{y_i = y_j\}\,\|f_i - f_j\|_2^2 + \mathbb{1}\{y_i \neq y_j\}\,\max\big(0,\ m - \|f_i - f_j\|_2\big)^2$$

$$\mathcal{L}^{m}_{\text{tri}}(x, x^+, x^-; f) = \max\big(0,\ \|f - f^+\|_2^2 - \|f - f^-\|_2^2 + m\big)$$
The contrastive loss and the triplet loss play similar roles: both pull same-class samples closer together and push different-class samples apart.
The pain point of the triplet loss is that each update looks at the distance to only one negative class and ignores all the others. With randomly sampled tuples, no single tuple can guarantee that the current update direction also pushes away the samples of every other negative class, so training often converges unstably or falls into a local optimum.
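
For concreteness, here is what a single-negative update looks like in PyTorch (a minimal sketch of my own, with random tensors standing in for embeddings; F.triplet_margin_loss implements a margin form of the triplet loss, with unsquared L2 distance by default):

import torch
import torch.nn.functional as F

# one anchor, one positive and exactly one negative per example --
# the embeddings of all the other negative classes never enter this update
anchor = torch.randn(32, 128)
positive = torch.randn(32, 128)
negative = torch.randn(32, 128)
loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.0)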

The N-pair loss
The N-pair loss was proposed to solve exactly this problem: the idea is to optimize against all negative classes at once. The (N+1)-tuplet loss is defined as:
$$\mathcal{L}\big(\{x, x^+, \{x_i\}_{i=1}^{N-1}\}; f\big) = \log\Big(1 + \sum_{i=1}^{N-1} \exp\big(f^\top f_i - f^\top f^+\big)\Big)$$

From the definition of the N-pair loss above, each update uses N-1 negative examples and one positive example. The paper uses the inner product to measure how close two vectors are (in my understanding this is akin to a cosine similarity: the larger the value, the closer the two vectors; the smaller, the farther apart). The paper also states that when N = 2 this approximates the triplet loss:
$$\mathcal{L}\big(\{x, x^+, x^-\}; f\big) = \log\big(1 + \exp(f^\top f^- - f^\top f^+)\big)$$

The paper also shows that the N-pair loss resembles the multi-class logistic loss (i.e., the softmax loss), which makes it very convenient to implement.
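
Concretely, each (N+1)-tuplet term rewrites as a softmax cross entropy, which is the identity the implementations below exploit:

$$\log\Big(1 + \sum_{i=1}^{N-1} \exp\big(f^\top f_i - f^\top f^+\big)\Big) = -\log \frac{\exp\big(f^\top f^+\big)}{\exp\big(f^\top f^+\big) + \sum_{i=1}^{N-1} \exp\big(f^\top f_i\big)}$$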

Efficient implementation of the N-pair loss (theory)
The batch construction turns a batch of N pairs $\{(x_i, x_i^+)\}_{i=1}^{N}$, one pair per class, into N tuplets, where the positives of the other pairs serve as the negatives; the resulting N-pair-mc loss from the paper is:

$$\mathcal{L}_{\text{N-pair-mc}}\big(\{(x_i, x_i^+)\}_{i=1}^{N}; f\big) = \frac{1}{N} \sum_{i=1}^{N} \log\Big(1 + \sum_{j \neq i} \exp\big(f_i^\top f_j^+ - f_i^\top f_i^+\big)\Big)$$

The author notes that optimizing the N-pair loss directly would require M*(N+1) examples, where M is the batch size and each update needs N+1 examples; done naively this is infeasible, since GPU memory cannot afford it. The proposed batch construction instead feeds the data in pairs, 2N examples per batch. After the network's mapping we get two matrices: one called anchors (shape N x num_features, one embedding per class) and one called positives (shape N x num_features, also one embedding per class). We then compute the matrix product of anchors with the transpose of positives. This step is the key: it computes, in one shot, the inner products between each anchor and its positive as well as all the negatives. How do we know which similarities to increase and which to decrease? A targets parameter is introduced: targets records the class of each pair.
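
As a minimal sketch of this construction (my own example; random tensors stand in for the network's embeddings, and every pair has a distinct class):

import torch
import torch.nn.functional as F

N, num_features = 8, 16
anchors = torch.randn(N, num_features)     # one anchor embedding per class
positives = torch.randn(N, num_features)   # the paired positive, same row order

# (N, N) logits: entry (i, j) is the inner product between anchor i and
# positive j, i.e. one positive and N-1 negatives per row
logits = anchors @ positives.t()

# with one pair per class, the positive of row i sits on the diagonal, so
# the N-pair-mc loss is exactly softmax cross entropy with targets 0..N-1
targets = torch.arange(N)
loss = F.cross_entropy(logits, targets)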
Code implementation

Reference: the TensorFlow implementation at
https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


def npairs_loss(labels, embeddings_anchor, embeddings_positive,
                reg_lambda=0.002, print_losses=False):
  """Computes the npairs loss.
  Npairs loss expects paired data where a pair is composed of samples from the
  same label and each pair in the minibatch has a different label. The loss
  has two components. The first component is the L2 regularizer on the
  embedding vectors. The second component is the sum of cross entropy loss
  which takes each row of the pair-wise similarity matrix as logits and
  the remapped one-hot labels as labels.
  See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
  Args:
    labels: 1-D tf.int32 `Tensor` of shape [batch_size/2].
    embeddings_anchor: 2-D Tensor of shape [batch_size/2, embedding_dim] for the
      embedding vectors for the anchor images. Embeddings should not be
      l2 normalized.
    embeddings_positive: 2-D Tensor of shape [batch_size/2, embedding_dim] for the
      embedding vectors for the positive images. Embeddings should not be
      l2 normalized.
    reg_lambda: Float. L2 regularization term on the embedding vectors.
    print_losses: Boolean. Option to print the xent and l2loss.
  Returns:
    npairs_loss: tf.float32 scalar.
  """
  # Add the regularizer on the embedding.
  reg_anchor = math_ops.reduce_mean(
      math_ops.reduce_sum(math_ops.square(embeddings_anchor), 1))
  reg_positive = math_ops.reduce_mean(
      math_ops.reduce_sum(math_ops.square(embeddings_positive), 1))
  l2loss = math_ops.multiply(
      0.25 * reg_lambda, reg_anchor + reg_positive, name='l2loss')

  # Get per pair similarities.
  similarity_matrix = math_ops.matmul(
      embeddings_anchor, embeddings_positive, transpose_a=False,
      transpose_b=True)

  # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor.
  lshape = array_ops.shape(labels)
  assert lshape.shape == 1
  labels = array_ops.reshape(labels, [lshape[0], 1])

  labels_remapped = math_ops.to_float(
      math_ops.equal(labels, array_ops.transpose(labels)))
  labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keepdims=True)

  # Add the softmax loss.
  xent_loss = nn.softmax_cross_entropy_with_logits(
      logits=similarity_matrix, labels=labels_remapped)
  xent_loss = math_ops.reduce_mean(xent_loss, name='xentropy')

  if print_losses:
    xent_loss = logging_ops.Print(
        xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss])

  return l2loss + xent_loss
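
A minimal usage sketch (my own example, TF 1.x graph mode, with random embeddings in place of network outputs):

import tensorflow as tf

labels = tf.constant([0, 1, 2], dtype=tf.int32)  # one class per pair
emb_anchor = tf.random_normal([3, 64])
emb_positive = tf.random_normal([3, 64])
loss = npairs_loss(labels, emb_anchor, emb_positive, reg_lambda=0.002)

with tf.Session() as sess:
    print(sess.run(loss))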

ChaofWang's PyTorch implementation:
https://github.com/ChaofWang/Npair_loss_pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.contrib.losses.python.metric_learning import metric_loss_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

def cross_entropy(logits, target, size_average=True):
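    # soft-target cross entropy written out by hand: the N-pair targets can be
    # soft distributions (a class may appear in more than one pair), which
    # older versions of F.cross_entropy do not accept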
    if size_average:
        return torch.mean(torch.sum(- target * F.log_softmax(logits, -1), -1))
    else:
        return torch.sum(torch.sum(- target * F.log_softmax(logits, -1), -1))


class NpairLoss(nn.Module):
    """the multi-class n-pair loss"""
    def __init__(self, l2_reg=0.02):
        super(NpairLoss, self).__init__()
        self.l2_reg = l2_reg

    def forward(self, anchor, positive, target):
        '''
        anchor and positive are paired data drawn from the same class;
        target indicates the class of each pair
        '''
        batch_size = anchor.size(0)
        target = target.view(target.size(0), 1)

        target = (target == torch.transpose(target, 0, 1)).float()
        target = target / torch.sum(target, dim=1, keepdim=True).float()

        logit = torch.matmul(anchor, torch.transpose(positive, 0, 1))
        loss_ce = cross_entropy(logit, target)
        l2_loss = torch.sum(anchor**2) / batch_size + torch.sum(positive**2) / batch_size

        loss = loss_ce + self.l2_reg*l2_loss*0.25
        return loss
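
A minimal usage sketch of NpairLoss (my own example; N pairs with distinct class ids, random tensors in place of embeddings):

criterion = NpairLoss(l2_reg=0.02)
anchor = torch.randn(8, 16)    # N x num_features anchor embeddings
positive = torch.randn(8, 16)  # paired positives, same row order
target = torch.arange(8)       # class id of each pair
loss = criterion(anchor, positive, target)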


class NpairsLossTest(test.TestCase):
  def testNpairs(self):
    with self.test_session():
      num_data = 16
      feat_dim = 5
      num_classes = 3
      reg_lambda = 0.02
      
      # First build the test data: anchors, positives, labels.
      # anchors and positives hold paired data: each row (one pair) is drawn
      # from the same class, and labels gives the class of each pair.
      # A standard batch has one pair per class, i.e. N (number of classes)
      # pairs (though it does not have to, e.g. when N is too large).
      embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32)
      embeddings_positive = np.random.rand(num_data, feat_dim).astype(np.float32)


      labels = np.random.randint(0, num_classes, size=(num_data)).astype(np.float32)

      # Reshape labels to compute adjacency matrix.
      labels_reshaped = np.reshape(labels, (labels.shape[0], 1))

      # Inner product of anchors with positives.T: pairwise similarities
      # between every anchor and every positive.
      similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T)

      # To decide which similarities to increase and which to decrease we
      # need a label matrix: a 1 in labels_remapped marks an entry of
      # similarity_matrix to pull closer (same class), a 0 marks one to
      # push apart.
      labels_remapped = np.equal(labels_reshaped, labels_reshaped.T).astype(np.float32)

      # Normalize each row of the label matrix.
      labels_remapped /= np.sum(labels_remapped, axis=1, keepdims=True)

      # As the paper notes, the loss is computed as a cross entropy.
      xent_loss = math_ops.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
          logits=ops.convert_to_tensor(similarity_matrix),
          labels=ops.convert_to_tensor(labels_remapped))).eval()

      # Compute the loss in NumPy.
      reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1))
      reg_term += np.mean(np.sum(np.square(embeddings_positive), 1))
      reg_term *= 0.25 * reg_lambda

      loss_np = xent_loss + reg_term

      # Compute the loss in PyTorch.
      npairloss = NpairLoss()
      loss_tc = npairloss(
                anchor=torch.tensor(embeddings_anchor),
                positive=torch.tensor(embeddings_positive),
                target=torch.from_numpy(labels)
                )

      # Compute the loss in TF
      loss_tf = metric_loss_ops.npairs_loss(
          labels=ops.convert_to_tensor(labels),
          embeddings_anchor=ops.convert_to_tensor(embeddings_anchor),
          embeddings_positive=ops.convert_to_tensor(embeddings_positive),
          reg_lambda=reg_lambda)
      loss_tf = loss_tf.eval()

      print('pytorch version: ', loss_tc.numpy())
      print('numpy version: ',loss_np)
      print('tensorflow version: ',loss_tf)
      # self.assertAllClose(loss_np, loss_tf)

if __name__ == '__main__':
    NpairsLossTest().testNpairs()
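
If the three implementations are consistent, the printed PyTorch, NumPy, and TensorFlow losses should agree up to floating point precision (the commented-out assertAllClose checks exactly that for the NumPy and TensorFlow values).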