Subclass-balancing Contrastive Learning for Long-tailed Recognition草稿版

Supervised contrastive learning.

核心公式解析

L S B C L = − ∑ i = 1 N ( 1 ∣ M ~ i ∣ ∑ z p ∈ M ~ i log ⁡ exp ⁡ ( z i ⋅ z p ⊤ / τ 1 ) ∑ z a ∈ V ~ i exp ⁡ ( z i ⋅ z a ⊤ / τ 1 ) + β 1 ∣ P ~ i ∣ − ∣ M i ∣ ∑ z p ∈ P ~ i / M i log ⁡ exp ⁡ ( z i ⋅ z p ⊤ / τ 2 ) ∑ z a ∈ V ~ i / M i exp ⁡ ( z i ⋅ z a ⊤ / τ 2 ) ) \begin{align*} & \mathcal{L}_{S B C L}=-\sum_{i=1}^{N}\left(\frac{1}{\left|\tilde{M}_{i}\right|} \sum_{z_{p} \in \tilde{M}_{i}} \log \frac{\exp \left(z_{i} \cdot z_{p}^{\top} / \tau_{1}\right)}{\sum_{z_{a} \in \tilde{V}_{i}} \exp \left(z_{i} \cdot z_{a}^{\top} / \tau_{1}\right)}\right. \\ & \left.+\beta \frac{1}{\left|\tilde{P}_{i}\right|-\left|M_{i}\right|} \sum_{z_{p} \in \tilde{P}_{i} / M_{i}} \log \frac{\exp \left(z_{i} \cdot z_{p}^{\top} / \tau_{2}\right)}{\sum_{z_{a} \in \tilde{V}_{i} / M_{i}} \exp \left(z_{i} \cdot z_{a}^{\top} / \tau_{2}\right)}\right) \tag{4} \end{align*} LSBCL=i=1N M~i 1zpM~ilogzaV~iexp(ziza/τ1)exp(zizp/τ1)+β P~i Mi1zpP~i/MilogzaV~i/Miexp(ziza/τ2)exp(zizp/τ2) (4)

公式4里面的负号是针对里面的两项的
所以,它有两个意思:

  • term1: 针对相同簇标签的实例之间的对比损失。这部分损失函数在计算时,比较了当前实例 x i x_i xi和同一簇中的正样本 z p z_p zp之间的相关性。具体而言,它通过计算它们的内积,并初一一个温度参数 τ 1 \tau_1 τ1来衡量它们的相似度,然后利用softmax函数来计算它们之间的相对概率。整个损失函数的目标是使同一簇中的实例在特征空间中更加接近。

  • term2: 针对不同簇中的相同类别标签实例之间的对比损失。这部分损失函数考虑了当前实例 x i x_i xi和其他不在同一簇中的同一类别标签的正样本 z p z_p zp之间的相关性。与第一部分不同的是,这里通过将同一类别标签的实例从损失计算中排除来确保实例的多样性。同样,这部分也利用内积和softmax函数来计算相对概率

  • 超参数 β \beta β在公式中用于平衡这两部分损失。它可以调整两种损失对模型训练的影响程度,以便更好地优化特征空间。第二部分的损失函数的目的是促使不同簇中的同一类别标签实例之间在特征空间中更加相似,以帮助模型更好地学习到数据的结构和特征。

综上所述,这个公式的目的是通过结合簇标签和类别标签的对比损失来优化模型的特征表示,使得相同簇内的实例更加相似,同时确保不同簇中的同一类别标签实例之间也能保持一定的相似度。

在公式4的第二部分中,通过将同一类别标签的实例从损失计算中排除,即只考虑不在同一簇中的同一类别标签的正样本,从而促使模型在学习时注重实例的多样性。具体来说,这种设计可以确保模型不会过于集中在同一簇内的实例之间的相似性,而是更加关注不同簇中的同一类别标签实例之间的相似性。

通过排除同一簇中的实例,模型被迫在学习时关注不同簇中的实例之间的关系,从而更好地学习到数据的全局结构和特征。这种方法有助于提高模型的泛化能力,使其能够更好地区分不同类别的实例,并增强对数据的表征能力。因此,通过排除同一簇中的实例,可以确保模型在学习时考虑到不同类别标签之间的多样性,从而提高模型的性能和表现。

也就是说“通过将同一类别标签的实例从损失计算中排除,即只考虑不在同一簇中的同一类别标签的正样本,”这样可以使得模型更加关注同一class的不同层次结构,而不是同一簇。把关注重心放在同一类,不同簇的层级结构上去。

是的,您理解的非常正确。通过将同一类别标签的实例从损失计算中排除,即只考虑不在同一簇中的同一类别标签的正样本,有助于使模型更加关注同一类别的不同层次结构,而不是仅限于同一簇内的实例。这样的设计可以帮助模型更好地理解和学习类别之间的层级结构和关系,从而提高模型对数据的表征能力和泛化能力。

通过重点关注同一类别标签的不同簇中的实例之间的关系,模型可以更好地捕捉类别之间的多样性和差异性,从而更好地区分不同类别。这种方法有助于模型学习到更加丰富和全面的数据特征,提高模型的性能和泛化能力。因此,将关注重心放在同一类别的不同层次结构上,而不仅限于同一簇内的实例,可以帮助模型更好地理解数据的类别结构和特征,从而提高模型的训练效果。

温度参数 τ \tau τ的作用

对于监督对比学习,温度参数 τ \tau τ的大小会影响特征分布的惩罚程度。具体来说,当温度较低时,对特征分布的惩罚较高,这实际上鼓励特征更加集中分布在一起。随着温度的增加,相对惩罚趋向更加均匀,这会使得特征的分布更加均匀。

温度参数 τ \tau τ在对比损失函数中的作用是控制样本之间的相似性度量。较低的温度值会增加样本之间的差异性,导致特征更加集中在一起,因为相似的样本对会更加突出,而不相似的样本对会有更大的惩罚。这会导致特征分布更加尖锐,即特征更加集中在一些特定的区域。

相反,较高的温度值会减少对样本之间差异性的强调,导致更加均匀的特征分布。这是因为相似性度量相对均匀,不同样本对之间的差异性减少,特征分布更加均匀。

因此,当温度较小时,对比损失函数更加强调样本之间的差异性,导致特征更加集中;而当温度较大时,对比损失函数更加均匀化样本之间的相似性度量,导致特征分布更加均匀。

温度越低,feature distribution越紧致。为了使得同一subclass越紧致,那就让 τ 1 < τ 2 \tau_1 < \tau_2 τ1<τ2

公式5解析

根据引用的参考文献[40],对于类别 c c c,我们定义 ϕ ( c ) \phi(c) ϕ(c)如下所示:

ϕ ( c ) = ∑ i = 1 n c ∣ z i − t c ∣ 2 n c log ⁡ ( n c + α ) (5) \phi(c)=\frac{\sum_{i=1}^{n_{c}}\left|z_{i}-t_{c}\right|{2}}{n{c} \log \left(n_{c}+\alpha\right)} \tag{5} ϕ(c)=nclog(nc+α)i=1nczitc2(5)
ϕ ( c ) \phi(c) ϕ(c)是用来衡量类别 c c c内实例与类别中心点之间的平均距离的度量。具体来说, ϕ ( c ) \phi(c) ϕ(c)的计算方式是将类别内每个实例与类别中心点的欧氏距离求和,然后除以类别中实例数量 n c n_c nc和一个对数项,以此来表示平均距离。在公式中, t c t_c tc代表类别 c c c的中心点, z i z_i zi代表类别 c c c的实例, α \alpha α是一个超参数用于调节公式的缩放,确保 ϕ ( c ) \phi(c) ϕ(c)不会过大。

通过计算 ϕ ( c ) \phi(c) ϕ(c),我们可以了解类别 c c c内实例与类别中心点的平均距离,从而衡量类别内实例的聚集程度。如果 ϕ ( c ) \phi(c) ϕ(c)较小,则表示类别内实例更加紧密地聚集在一起;而如果 ϕ ( c ) \phi(c) ϕ(c)较大,则表示类别内实例分布较散,距离类别中心点较远。

ϕ ( c ) \phi(c) ϕ(c)的值被用来计算类别 c c c的温度 τ 2 ( c ) \tau_{2}(c) τ2(c),以便根据类别内实例的聚集程度调整温度值,从而影响特征的分布,使得类别内实例的特征形成比子类别更松散的聚类。因此, ϕ ( c ) \phi(c) ϕ(c)在这里起到了指导调整温度值以控制特征分布的作用。

其中, t c t_{c} tc是类别 c c c的中心点(centroid), α \alpha α是一个超参数,用于确保 ϕ ( c ) \phi(c) ϕ(c)不会过大,而 z i z_{i} zi对应于属于类别 c c c的实例。从这个公式可以看出,如果当前到类别中心的平均距离较大,或者类别包含的数据较少,那么温度将会被设置为较大的值,以便在训练过程中采用类别 c c c的特征分布。接着,我们定义类别 c c c的温度为:

τ 2 ( c ) = τ 1 ⋅ exp ⁡ ( ϕ ( c ) 1 C ∑ i = 1 C ϕ ( i ) ) (6) \tau_{2}(c)=\tau_{1} \cdot \exp \left(\frac{\phi(c)}{\frac{1}{C} \sum_{i=1}^{C} \phi(i)}\right) \tag{6} τ2(c)=τ1exp(C1i=1Cϕ(i)ϕ(c))(6)

这样, τ 2 ( c ) \tau_{2}(c) τ2(c)对于类别标签始终大于 τ 1 \tau_{1} τ1对于簇标签(因为 ϕ ( c ) > 0 \phi(c)>0 ϕ(c)>0),并且能反映类别中实例集中程度的当前水平。具体来说,提出的 τ 2 ( c ) \tau_{2}(c) τ2(c)鼓励类别 c c c中实例的特征形成比子类别更松散的聚类(通过 τ 2 ( c ) > τ 1 \tau_{2}(c)>\tau_{1} τ2(c)>τ1),同时自适应地调整温度以避免过度松散或过度密集的聚类。

Algorithm1解析

这个算法描述的是一种自适应的子类平衡聚类方法。让我来解释一下它的主要步骤和目的:

  1. 算法输入:

    • 样本集 S
    • 阈值 M
    • 迭代次数 K
  2. 主要流程:

    • 进行K轮迭代
    • 第一轮选择距离较远的点作为初始聚类中心
    • 之后的轮次更新聚类中心为该类所有样本的平均
    • 每轮中,不断将样本分配给最相似的聚类中心
    • 如果某个类的样本数超过阈值M,则删除该聚类中心
  3. 算法目的:

    • 通过迭代优化聚类结果
    • 控制每个类的样本数量,实现子类平衡
    • 自适应地调整聚类中心数量
  4. 主要特点:

    • 使用余弦相似度衡量样本与聚类中心的相似性
    • 通过删除过大的类来平衡各个子类
    • 聚类中心数量可以动态变化

总的来说,这是一种旨在得到平衡的子类划分的自适应聚类算法,它能根据数据分布自动调整聚类结果。

“第一轮选择距离较远的点作为初始聚类中心
之后的轮次更新聚类中心为该类所有样本的平均”这两步具体是什么意思呢?请你举个例子详细说明

举例说明如下:

  1. 第一轮选择距离较远的点作为初始聚类中心:

这一步的目的是为了选择初始的聚类中心,使它们彼此之间尽可能分散。这样可以更好地覆盖整个数据空间。

例子:
假设我们有一个二维平面上的数据集,包含以下点:
A(1,1), B(2,2), C(8,8), D(9,9), E(5,5)

步骤:
a. 随机选择第一个中心,比如A(1,1)
b. 计算所有点到A的距离,选择最远的点作为第二个中心,这里是D(9,9)
c. 如果需要第三个中心,就选择到A和D距离之和最大的点,这里可能是E(5,5)

这样,我们得到的初始聚类中心是分散的: A(1,1), D(9,9), E(5,5)

  1. 之后的轮次更新聚类中心为该类所有样本的平均:

在后续的迭代中,每个聚类的中心会被更新为属于该类的所有样本的平均位置。

继续上面的例子:
假设经过一轮聚类后,我们得到了以下分类:
类1: A(1,1), B(2,2)
类2: C(8,8), D(9,9)
类3: E(5,5)

更新聚类中心:
类1新中心: ((1+2)/2, (1+2)/2) = (1.5, 1.5)
类2新中心: ((8+9)/2, (8+9)/2) = (8.5, 8.5)
类3新中心: (5, 5) (保持不变,因为只有一个点)

这个过程会在每次迭代中重复,使聚类中心逐渐移动到更能代表该类所有样本的位置。

这两个步骤的组合可以帮助算法找到更好的聚类结果:第一步确保初始中心分散,覆盖整个数据空间;第二步则通过迭代优化中心位置,使其更好地代表各自的类。

Algorithm 2

这个算法描述的是一种训练特征提取器的方法,主要结合了子类平衡和对比学习。这个算法的主要步骤和目的:

  1. 算法输入:

    • 数据集 D
    • 聚类分配更新间隔 K
    • 预热阶段的轮数 T₀
    • 总训练轮数 T
    • 超参数 β 和 δ
  2. 主要流程:
    a. 初始化模型参数 θ
    b. 使用SCL(监督对比学习)或KCL(知识对比学习)进行T₀轮预热训练
    c. 从T₀到T轮进行主要训练:

    • 每K轮更新一次聚类分配和温度参数
    • 使用公式4进行子类平衡对比学习
  3. 算法目的:

    • 训练一个能够提取平衡和有代表性特征的模型
    • 通过对比学习和子类平衡来改善特征提取器的性能
  4. 主要特点:

    • 包含预热阶段,可能用于初始化或预训练
    • 定期更新聚类分配,可能是为了适应模型的变化
    • 使用动态温度参数,可能用于调整学习难度
    • 结合了子类平衡和对比学习的思想

这个算法的核心思想是通过对比学习来训练特征提取器,同时考虑到子类的平衡性。它通过定期更新聚类和温度参数来适应训练过程中的变化,可能旨在解决类别不平衡或长尾分布等问题。

Algorithm1和Algorithm2之间的关系

算法1和算法2之间存在密切的关系,它们是一个完整训练过程的两个关键组成部分。让我解释一下它们的关系:

  1. 互补功能:

    • 算法1(子类平衡自适应聚类)主要负责数据的聚类和子类划分。
    • 算法2(训练算法)使用这些聚类结果来训练特征提取器。
  2. 数据流动:

    • 算法1的输出(聚类分配)成为算法2的输入之一。
    • 在算法2中,每隔K轮就会更新聚类分配,这个更新过程很可能是调用算法1来完成的。
  3. 迭代优化:

    • 算法2在训练过程中periodically更新聚类,这意味着它可能多次调用算法1。
    • 每次调用算法1时,都会使用当前训练的特征提取器f_θ(x)来进行聚类,这样聚类结果会随着特征提取器的改进而优化。
  4. 子类平衡的实现:

    • 算法1通过控制每个类的样本数量来实现子类平衡。
    • 算法2利用这种平衡的子类划分来进行对比学习,从而训练出更好的特征提取器。
  5. 动态适应:

    • 算法1能够动态调整聚类中心的数量。
    • 算法2利用这种动态调整的结果来更新其训练策略,特别是在更新温度参数τ₂时。

总的来说,算法1和算法2形成了一个闭环:算法1提供平衡的子类划分,算法2使用这些划分来训练特征提取器,而improved的特征提取器又能帮助算法1产生更好的聚类结果。这种相互促进的关系有助于整体模型性能的提升,特别是在处理不平衡或长尾分布数据时。

代码解读

class SupConLoss(nn.Module):

    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR"""

    def __init__(self,temperature=0.1, contrast_mode='all',

                 base_temperature=0.1):

        super(SupConLoss, self).__init__()

        self.temperature = temperature

        self.contrast_mode = contrast_mode

        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None): # # features.shape = (1024, 2, 128)     target.shape = (1024)

        """Compute loss for model. If both `labels` and `mask` are None,

        it degenerates to SimCLR unsupervised loss:

        https://arxiv.org/pdf/2002.05709.pdf

        Args:

            features: hidden vector of shape [bsz, n_views, ...].

            labels: ground truth of shape [bsz].

            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j

                has the same class as sample i. Can be asymmetric.

        Returns:

            A loss scalar.

        """

        device = (torch.device('cuda')

                  if features.is_cuda

                  else torch.device('cpu'))

        if len(features.shape) < 3:

            raise ValueError('`features` needs to be [bsz, n_views, ...],'

                             'at least 3 dimensions are required')

        if len(features.shape) > 3:

            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0] # batch_size = 1024

       

        if labels is not None and mask is not None: # label.shape = (1024) mask = None

            raise ValueError('Cannot define both `labels` and `mask`')

        elif labels is None and mask is None:

            mask = torch.eye(batch_size, dtype=torch.float32).to(device)

        elif labels is not None:

            labels = labels.contiguous().view(-1, 1)  # labels.shape = (1024, 1)

            if labels.shape[0] != batch_size:

                raise ValueError('Num of labels does not match num of features')

            mask = torch.eq(labels, labels.T).float().to(device) # 当提供labels时,根据labels生成mask矩阵,其中相同类的样本对比位置为1

        else:

            mask = mask.float().to(device)

        contrast_count = features.shape[1] # contrast_count = 2

        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) # (2048, 128)

        if self.contrast_mode == 'one':

            anchor_feature = features[:, 0]

            anchor_count = 1

        elif self.contrast_mode == 'all':

            anchor_feature = contrast_feature # anchor_feature.shape = (2048, 128)

            anchor_count = contrast_count  # contrast_mode有什么作用呢? anchor_count是干啥的?

        else:

            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

       

        # compute logits

        anchor_dot_contrast = torch.div(

            torch.matmul(anchor_feature, contrast_feature.T),

            self.temperature)

        # for numerical stability  为了数值稳定性,减去最大值

        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)

        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask   扩展掩码以适应anchor数量和对比数量,掩码掉对自身的对比

        mask = mask.repeat(anchor_count, contrast_count)

        # mask-out self-contrast cases

        logits_mask = torch.scatter(

            torch.ones_like(mask),

            1,

            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),

            0

        )

        mask = mask * logits_mask

        # compute log_prob

        exp_logits = torch.exp(logits) * logits_mask

        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive

        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss

        loss = - (self.temperature/ self.base_temperature) * mean_log_prob_pos

        loss = loss.view(anchor_count, batch_size).mean()

        return loss

这个类 SupConLoss 实现了监督对比损失(Supervised Contrastive Loss),它能够同时支持监督对比学习和无监督对比学习 (SimCLR)。这个类的设计基于论文 “Supervised Contrastive Learning” 中提出的算法,该算法旨在通过不同样本对之间的相似性和差异性学习一个更好的特征表示。

这里是该类的详细分解和解释:

类和初始化

class SupConLoss(nn.Module):
    def __init__(self, temperature=0.1, contrast_mode='all', base_temperature=0.1):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

这个类从 PyTorch 的 nn.Module 继承,并定义了一些初始化参数:

  • temperature: 对比损失的温度系数,控制对比损失的"平滑"程度。
  • contrast_mode: 对比的模式,有两个可能的值 'all''one''all' 包含所有视图,而 'one' 仅包括一个视图。
  • base_temperature: 基础温度,用于调整损失计算中的温度。

forward 函数

def forward(self, features, labels=None, mask=None):

forward 方法是核心部分,计算对比损失。它接收以下参数:

  • features: 形状为 [bsz, n_views, ...] 的隐藏特征向量。
  • labels: 形状为 [bsz] 的标签,用于监督对比学习。
  • mask: 对比掩码矩阵, 形状为 [bsz, bsz]。当两个样本属于同一类时,mask_{i,j}=1。 当提供labels时,根据labels生成mask矩阵,其中相同类的样本对比位置为1

设备检测和输入验证

device = (torch.device('cuda') if features.is_cuda else torch.device('cpu'))
if len(features.shape) < 3:
    raise ValueError('`features` needs to be [bsz, n_views, ...], at least 3 dimensions are required')
if len(features.shape) > 3:
    features = features.view(features.shape[0], features.shape[1], -1)
batch_size = features.shape[0]

首先,代码检测是否在GPU上运行,然后验证输入特征的形状。如果特征维度超过三维,则将其重新整形。

标签和掩码处理

if labels is not None and mask is not None:
    raise ValueError('Cannot define both `labels` and `mask`')
elif labels is None and mask is None:
    mask = torch.eye(batch_size, dtype=torch.float32).to(device)
elif labels is not None:
    labels = labels.contiguous().view(-1, 1)
    if labels.shape[0] != batch_size:
        raise ValueError('Num of labels does not match num of features')
    mask = torch.eq(labels, labels.T).float().to(device)
else:
    mask = mask.float().to(device)

这部分逻辑处理标签和掩码的兼容性。如果同时存在 labelsmask 会报错;如果两者都不存在,则默认使用对角矩阵作为掩码;如果存在 labels,则根据 labels 生成掩码矩阵。

特征整理与对比模式

contrast_count = features.shape[1]
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
if self.contrast_mode == 'one':
    anchor_feature = features[:, 0]
    anchor_count = 1
elif self.contrast_mode == 'all':
    anchor_feature = contrast_feature
    anchor_count = contrast_count
else:
    raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

这里首先对特征进行整理,将特征展开并连接。根据 contrast_mode 的设置,选择对比模式是 'one' 还是 'all' ,决定要用哪些特征作为锚点特征。

计算对比损失

anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

这部分计算了锚点特征和对比特征的点积(相似性),并除以温度以获得 logits。为了数值稳定性,减去最大的 logit

contrast_mode的不同对比模式分别是什么意思?有什么作用?有什么不同呢?

contrast_mode 参数在 SupConLoss 类中提供了两种不同的对比模式:'one''all'。这两种模式的主要区别在于它们在对比学习过程中选择的锚点(anchor)特征不同。

让我们详细解释一下这两种模式及其区别和用途。

'one' 模式

含义
'one' 模式下,每个样本的第一个视图(view)作为锚点特征来进行对比。

作用

  • 这种模式仅使用每个样本的第一个视图作为对比的基础。
  • 适用于某些场景,例如只有单一视角(view)的数据,或你仅希望对一个特定视角进行对比学习。

实现

if self.contrast_mode == 'one':
    anchor_feature = features[:, 0]  # 只取第一个视图的特征
    anchor_count = 1                # 锚点数目为1

'all' 模式

含义
'all' 模式下,将所有视图的特征展开,并将每个视图的特征都作为锚点进行对比。

作用

  • 这种模式下,每个样本的每个视角都用作锚点,自我对比。
  • 适用于多视角(view)的数据,利用所有可能的信息进行对比训练,可以更充分地利用数据。

实现

elif self.contrast_mode == 'all':
    anchor_feature = contrast_feature  # 取所有视图展开后的特征
    anchor_count = contrast_count      # 锚点数目为视图数

区别与应用

模式含义作用
'one'仅使用每个样本的第一个视图作为锚点对于仅希望对一个特定视角进行对比学习的场景
'all'使用所有视图的特征,并将每个视图都作为锚点更充分地利用多视角数据,适用于有多个视图的场景

举例说明

假设我们有一批次的数据,每个样本有两个视角(n_views = 2)。

  • 假设特征矩阵 features 形状为 (batch_size, 2, feature_dim):

'one' 模式

features[:, 0]  # 形状为 (batch_size, feature_dim)

每个样本的第一个视图作为锚点,不考虑第二个视图。

'all' 模式

contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)  # 形状为 (2 * batch_size, feature_dim)

所有视图展开,然后作为锚点。例如,假设 feature_dim=128batch_size=1024

  • 展开后 contrast_feature 形状为 (2048, 128),包含所有视图的特征。

总结

  • 'one' 模式简单,只考虑一个视角,适用于单一视角的数据,或在某些特定场景下希望只使用特定视角。
  • 'all' 模式充分利用所有视角,适用于多视角数据,更加全面,可以挖掘更多信息。

根据你的数据和应用场景选择合适的模式,可以取得更好的效果。

但是要是考虑多视角的话,那这个contrast_feature的batch_size就变为原来的两倍了哎,为啥要这样设置?为啥考虑所有视角的话,是把batchsize翻倍呢

解释为什么在 'all' 模式下批次大小会翻倍,并进一步说明这样设置的原因。

详细解释

假设我们有以下变量:

  • batch_size = N
  • n_views = 2 例如图像的两种不同增强。
  • feature_dim = 128 即特征向量维度。

原始特征矩阵 features

形状为 [N, n_views, feature_dim]。例如,features 可能是:

features = torch.randn(N, 2, 128)

'all' 模式下的特征展开

首先,使用 torch.unbind(features, dim=1) 将特征矩阵沿着 n_views 维度拆分为两个 (N, 128) 的子矩阵。这是为了获取每个视角的特征。然后,通过 torch.cat(..., dim=0) 将这两个子矩阵在第一维度(batch 维度)上重新拼接,生成一个新的特征矩阵 contrast_feature

contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

结果,contrast_feature 的形状为 [2*N, 128]

意图:最大化数据利用

通过这种方式,每个视角的特征将分别进行对比学习。举一个具体的例子更加清楚:

示例

假设 batch_size 是 2,每个样本有 2 个视角。

  1. 原始特征矩阵 features
features = 
[
  [[f11, f12, ..., f1_128], [f21, f22, ..., f2_128]],   # 第一个样本的两个视角
  [[f31, f32, ..., f3_128], [f41, f42, ..., f4_128]]    # 第二个样本的两个视角
]

形状是 [2, 2, 128]

  1. 进行 torch.unbindtorch.cat
contrast_feature = 
[
  [f11, f12, ..., f1_128],
  [f31, f32, ..., f3_128],
  [f21, f22, ..., f2_128],
  [f41, f42, ..., f4_128]
]

形状变为 [4, 128],即 2 * 2, 128

损失计算

'all' 模式下,所有特征都作为对比学习的锚点和对比特征。通过这种方式,损失能考虑所有视角的特性,防止信息丢失,达到更加全面、更有效的对比学习。

优势

这种设计具有以下主要优势:

  1. 充分利用数据:确保每一个视角的信息都被用到,不浪费任何数据。
  2. 增强表示能力:通过多视角对比,可以获得更加鲁棒的特征表示,适用于泛化性能更好的模型。

总结

总之,contrast_feature 的批次大小之所以会翻倍,是为了在 'all' 模式下最大化地利用每个视角的信息。通过将所有视角展开并参与对比学习,可以确保学习到的特征更具代表性和鲁棒性。

我懂了,这个也就是相当于,把不同的channel给搞到batch上去,然后对每个channel的feature都进行对比。比如原始数据的shape为(batch, channel, feature_dim),然后选择all模式,就要把channel展平到batch维度上去,也就是数据shape变为(batch*channel, feature_dim),这样做的好处就是在进行对比学习的时候,可以转化为矩阵相乘,方便运算,而且也考虑到了每个channel的对比。我理解的对吗?

是的,你的理解是正确的!在多视角对比学习(如 'all' 模式)中,将不同视角(你称为“channels”)展开到批次维度(batch dimension)上去进行对比,确实是为了便于运算和充分利用每个视角的信息。这样做的主要好处包括:

  1. 方便矩阵操作:将多视角展开到批次维度上后,可以利用高效的矩阵乘法操作来计算对比损失,充分利用 GPU 等硬件加速的优势。

  2. 全面对比学习:这样确保每个视角的特征都参与到对比过程中,防止信息丢失,从而使得学到的表示更加全面和鲁棒。

更详细的实现思路

假设 features 的形状为 (batch_size, n_views, feature_dim),通过 'all' 模式的展开操作实现将 n_views 展平到批次维度上:

contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)  # 这个操作将 shape 从 (batch_size, n_views, feature_dim) 变为 (batch_size * n_views, feature_dim)

举具体的例子来更好地理解:

示例

假设输入特征矩阵 features 形状为 (4, 2, 128),这里 4 是批次大小(batch size),2 是视角数(n_views),128 是特征维度(feature dimension)。

原始 features 的示例数据结构大致如下:

features = [
             [[f11, f12, ..., f1_128], [f21, f22, ..., f2_128]],  # 第一个样本的两个视角
             [[f31, f32, ..., f3_128], [f41, f42, ..., f4_128]],  # 第二个样本的两个视角
             [[f51, f52, ..., f5_128], [f61, f62, ..., f6_128]],  # 第三个样本的两个视角
             [[f71, f72, ..., f7_128], [f81, f82, ..., f8_128]]   # 第四个样本的两个视角
           ]

通过 torch.unbindtorch.cat 操作得到展开后的 contrast_feature

contrast_feature = [
                     [f11, f12, ..., f1_128],  # 第一个样本的第一个视角
                     [f21, f22, ..., f2_128],  # 第一个样本的第二个视角
                     [f31, f32, ..., f3_128],  # 第二个样本的第一个视角
                     [f41, f42, ..., f4_128],  # 第二个样本的第二个视角
                     [f51, f52, ..., f5_128],  # 第三个样本的第一个视角
                     [f61, f62, ..., f6_128],  # 第三个样本的第二个视角
                     [f71, f72, ..., f7_128],  # 第四个样本的第一个视角
                     [f81, f82, ..., f8_128]   # 第四个样本的第二个视角
                   ]

形状变为 (8, 128),即 (4 * 2, 128)

计算损失

在展开的特征空间中,损失计算通过对比各个视角的特征,通过矩阵乘法等高效操作进行。假设我们使用的是对比损失(如 InfoNCE):

  1. 计算相似性矩阵:可以通过特征矩阵内积快速计算:
    similarity_matrix = torch.matmul(contrast_feature, contrast_feature.T)
    
  2. 构造对比学习损失:通过交叉熵损失等方法计算对比损失。

优势总结

  • 计算高效:对比学习中的矩阵操作能充分利用硬件加速。
  • 信息全面:每个视角(channel)的信息都能被学习到,增强模型的表示能力。

你的总结非常准确,这样的设计确实是为了提高计算效率和全面利用数据的每个视角。希望这些解释能更好地帮助你理解多视角对比学习为什么要这样展开批次维度。

  • 12
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值