assert函数_ECCV2020 | SmoothAP:用于大规模图像检索的平滑损失函数,解决不可微问题...

点击上方“AI算法修炼营”,选择“星标”公众号

精选作品,第一时间送达

本文首发自:https://zhuanlan.zhihu.com/p/163413041

75623ce7b5bd4c8b2b42daa55ab1fdd0.png

论文地址:https://arxiv.org/abs/2007.12163

代码地址:https://github.com/Andrew-Brown1/Smooth_AP

视频讲解:https://www.bilibili.com/video/BV1dD4y1m7fA

一、简介

图像检索通常是,给定一个包含特定实例(例如特定目标、场景、建筑等)的查询图像,图像检索旨在从数据库图像中找到包含相同实例的图像。但由于不同图像的拍摄视角、光照、或遮挡情况不同,如何设计出能应对这些类内差异的有效且高效的图像检索算法仍是一项研究难题。

图像检索的典型流程 首先,设法从图像中提取一个合适的图像的表示向量。其次,对这些表示向量用欧式距离或余弦距离进行最近邻搜索以找到相似的图像。最后,可以使用一些后处理技术对检索结果进行微调。可以看出,决定一个图像检索算法性能的关键在于提取的图像表示的好坏。

不同于以往基于度量学习的损失函数,作者提出了基于优化排序的损失函数选择的优化对象是AP(Average Precision),但是AP是不可微的,所以提出了smooth AP,具体做法是写了AP估值计算后,将其中的不可微部分换成sigmoid函数。在Stanford Online products,VehicleID,INaturalist,VGGFace2 and IJB-C上做了实验,结果不错。结果示意:

71e1bce7eeca92daa91ee7e5971e983f.png

二、本文方法

Notations

  • 640?wx_fmt=svg :retrieval set

  • 640?wx_fmt=svg :query instance

  • 640?wx_fmt=svg640?wx_fmt=svg :positive and negative set

  • 640?wx_fmt=svg :cosine similarity。640?wx_fmt=svg640?wx_fmt=svg :positive and negative relevance score sets;640?wx_fmt=svg :query vector;640?wx_fmt=svg :vectorized retrieval set。

  • 640?wx_fmt=svg :AP。640?wx_fmt=svg :the rankings of the instance i。

  • 640?wx_fmt=svg :ranking R。640?wx_fmt=svg :an indicator function。

  • 640?wx_fmt=svg :a difference matrix

  • 640?wx_fmt=svg

优化AP其实是最小化 640?wx_fmt=svg ,就是排序时候不要让负样本排到正样本前面。

71747ec613af88453ed749ebbffefe57.png

Smooth AP

上面的 indicator function 不能被基于梯度的方法优化。

所以改为sigmoid:640?wx_fmt=svg640?wx_fmt=svg 是平滑系数。

AP的估计值重写为 :

640?wx_fmt=svg

损失函数为:640?wx_fmt=svg

此外还有三点分析。

  • 第一点是平滑系数越小,AP的估计值越接近真实AP,而越大的平滑系数会带来更大的操作空间,就是图二里求导后的曲线下方面积,可以提供更多的梯度信息。

  • 第二点是triplet loss 更像是度量损失而不是优化排序。

  • 第三点是相对于其他优化AP的方法 FastAP and Blackbox AP,本方法更简单,并且估计的更准。而且这俩方法和triplet loss 一样,可能更像度量损失。

class SmoothAP(torch.nn.Module):
    """PyTorch implementation of the Smooth-AP loss.
    implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns
    the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
    e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
        labels = ( A, A, A, B, B, B, C, C, C)
    (the order of the classes however does not matter)
    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
    mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
    same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
    Args:
        anneal : float
            the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
            results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
        batch_size : int
            the batch size being used during training.
        num_id : int
            the number of different classes that are represented in the batch.
        feat_dims : int
            the dimension of the input feature embeddings
    Shape:
        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
        - Output: scalar
    Examples::
        >>> loss = SmoothAP(0.01, 60, 6, 256)
        >>> input = torch.randn(60, 256, requires_grad=True).cuda()
        >>> output = loss(input)
        >>> output.backward()
    """

    def __init__(self, anneal, batch_size, num_id, feat_dims):
        """
        Parameters
        ----------
        anneal : float
            the temperature of the sigmoid that is used to smooth the ranking function
        batch_size : int
            the batch size being used
        num_id : int
            the number of different classes that are represented in the batch
        feat_dims : int
            the dimension of the input feature embeddings
        """
        super(SmoothAP, self).__init__()

        assert(batch_size%num_id==0)

        self.anneal = anneal
        self.batch_size = batch_size
        self.num_id = num_id
        self.feat_dims = feat_dims

    def forward(self, preds):
        """Forward pass for all input predictions: preds - (batch_size x feat_dims) """


        # ------ differentiable ranking of all retrieval set ------
        # compute the mask which ignores the relevance score of the query to itself
        mask = 1.0 - torch.eye(self.batch_size) 
        mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
        sim_all = compute_aff(preds)
        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
        # compute the difference matrix
        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
        # pass through the sigmoid
        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.cuda()
        # compute the rankings
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1

        # ------ differentiable ranking of only positive set in retrieval set ------
        # compute the mask which only gives non-zero weights to the positive set
        xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
        pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
        pos_mask = pos_mask.unsqueeze(dim=0).unsqueeze(dim=0).repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
        # compute the relevance scores
        sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(1, 1, int(self.batch_size / self.num_id), 1)
        # compute the difference matrix
        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
        # pass through the sigmoid
        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.cuda()
        # compute the rankings of the positive set
        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1

        # sum the values of the Smooth-AP for all instances in the mini-batch
        ap = torch.zeros(1).cuda()
        group = int(self.batch_size / self.num_id)
        for ind in range(self.num_id):
            pos_divide = torch.sum(sim_pos_rk[ind] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]))
            ap = ap + ((pos_divide / group) / self.batch_size)

        return (1-ap)

三、实验

数据集:

5a71757592c081b92786485175bcb48f.png

SOP结果和AP估计,可以看到平滑系数越小,估计越准:

6f1515dc2b31466d5daa2bb863a9c18f.png

VehicleID 和 INaturalist 数据集上的结果:

a5cde998f6f549e25c4d06d2e3f0b69f.png

人脸数据集上的结果:

e8c3efbc7e447b41c79066e7ffe8ea0b.png

消融实验。第二个表,越小的P意味着一个batch里面其他类别更多,负样本数更多,排序出现负样本出现在正样本前面的概率越大,更有利于网络学习。

5eafe6a78453384f97e6876144c91c6c.png

定性结果:

bd96764c9bee9c881aad7c0e7acce6b9.png

b8ba4c0ef989e3d8981d176ffe5a3d1b.png

d683d623cb9bcc2ce7bdfb9b6fa39535.png

1f7f6dc1e603bb677add28a1ec3820b9.png

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
可以使用 `tf.debugging.assert_equal` 函数来确保 `logits` 和 `labels` 的形状匹配。这个函数会检查两个张量的形状是否相同,如果不相同,则会抛出异常并停止程序的运行。下面是一个简单的例子: ```python import tensorflow as tf logits = tf.random.normal([64, 10]) labels = tf.random.uniform([64], maxval=10, dtype=tf.int32) tf.debugging.assert_equal(tf.shape(logits), tf.shape(labels)) ``` 在这个例子中,`logits` 的形状是 `[64, 10]`,`labels` 的形状是 `[64]`,我们使用 `tf.debugging.assert_equal` 函数来检查这两个张量的形状是否相同。如果这两个张量的形状不同,程序会抛出异常并停止运行。 在使用交叉熵损失函数训练神经网络时,可以在每个 batch 计算损失时加入这个检查,例如: ```python import tensorflow as tf model = tf.keras.Sequential([...]) # 定义模型 optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) # 定义优化器 loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) # 定义损失函数 for epoch in range(num_epochs): for step, (x_batch_train, y_batch_train) in enumerate(train_dataset): with tf.GradientTape() as tape: logits = model(x_batch_train, training=True) loss_value = loss_fn(y_batch_train, logits) tf.debugging.assert_equal(tf.shape(logits), tf.shape(y_batch_train)) # 检查形状是否匹配 gradients = tape.gradient(loss_value, model.trainable_weights) optimizer.apply_gradients(zip(gradients, model.trainable_weights)) ``` 在这个例子中,我们使用 `tf.debugging.assert_equal` 函数来确保 `logits` 和 `y_batch_train` 的形状匹配。如果形状不匹配,程序会抛出异常并停止运行。这样可以避免因为形状不匹配导致的训练错误,提高代码的鲁棒性。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值