对比学习系列(一)---InfoNce损失理解+MoCo伪代码

InfoNce损失理解

  1. InfoNCE损失和Cross Entropy损失的关系。
    先看Cross Entropy损失:
    L o s s = − ∑ i = 1 n y i l o g ( y ^ i ) Loss=-\sum_{i=1}^{n}y_i log(\hat y_i) Loss=i=1nyilog(y^i)
    在有监督学习下,ground-truth是一个one-hot向量,对softmax的结果取-log,再与ground-truth相乘之后得到交叉熵损失。
    L o s s = − l o g e x p ( z ′ ) ∑ i = 1 n e x p ( z i ) Loss=-log \frac{exp(z^{\prime})}{\sum_{i=1}^{n} exp(z_i)} Loss=logi=1nexp(zi)exp(z)
    上式中的n指的是有监督学习中数据集一共有的类别,比如ImageNet为1000个类别。
    交叉熵通常是 衡量两个概率分布之间差异的指标。(在分类问题中,通常有一个真实的概率分布P(一般为一个one-hot编码,代表样本的真实标签))和一个模型预测的概率分布Q。
    C E ( P , Q ) = − ∑ i = 1 n P ( i ) l o g ( Q ( i ) ) CE(P,Q)=-\sum_{i=1}^{n}P(i)log(Q(i)) CE(P,Q)=i=1nP(i)log(Q(i))
    交叉熵损失函数优化的目标是使得模型预测的概率分布尽量和真实标签的概率分布接近。
    再看NCE损失函数:
    其核心思想在于将多分类问题转换为二分类问题,解决多分类问题中类别太多时 softmax 的计算问题。一类是数据类别,另一类是噪声类别,学习数据样本和噪声样本之间的区别。
    但是把整个数据集剩下的数据都当作负样本,虽然接近了类别多的问题,但是计算机复杂度并没有降下来,解决方法便是做负样本采用来计算loss,就是estimation的意思。一般来说,负样本取得越多,效果越好。
    再看InfoNce损失函数:
    如果只把问题划分为二分类问题,对模型学习并不友好,很多噪声样本可能就不是一个类,因此需要把它们看成多分类问题。
    InfoNCE Loss  ⁡ = − 1 N ∑ i = 1 N log ⁡ ( exp ⁡ ( q i ⋅ k i + τ ) ∑ j = 1 N exp ⁡ ( q i ⋅ k j − τ ) ) \operatorname{InfoNCE~Loss~}=-\frac{1}{N} \sum_{i=1}^N \log \left(\frac{\exp \left(\frac{q_i \cdot k_{i+}}{\tau}\right)}{\sum_{j=1}^N \exp \left(\frac{q_i \cdot k_{j-}}{\tau}\right)}\right) InfoNCE Loss =N1i=1Nlog j=1Nexp(τqikj)exp(τqiki+)
    分母的和其实就是一个正样本和N-1个负样本做的,一共N个样本。
    InfoNce loss其实就相当于一个Cross Entropy loss,做的是一个N类的分类任务,目的是将图片分到这个相同类别中。

  2. InfoNCE损失的温度系数

MoCo伪代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear1 = nn.Linear(in_dim, out_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_dim, out_dim)
    
    def forward(self, x):
        x = self.linear2(self.relu(self.linear1(x)))
        return x
        
class MoCo(nn.Module):
    def __init__(self, base_encoder, in_dim, out_dim, K=256, m=0.9, T=0.07):
        super().__init__()
        
        self.m = m # 动量更新系数
        self.K = K # 队列大小
        self.T = T # 温度系数
        
        # 创建编码器
        self.encoder_q = base_encoder(in_dim, out_dim)
        self.encoder_k = base_encoder(in_dim, out_dim)
        
        # 创建动态队列
        self.register_buffer("queue", torch.rand(K, out_dim))
        self.queue = F.normalize(self.queue, dim=1)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        
        # 初始化权重
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        
    # 使用动量更新key编码器参数
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q * (1 - self.m)
    
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # 取出队列指针的位置
        batch_size = keys.shape[0]
        
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # 确保队列大小是batch_size的整数倍
        
        # 更新队列中的keys
        self.queue[ptr:ptr+batch_size, :] = keys
        ptr = (ptr + batch_size) % self.K
        
        self.queue_ptr[0] = ptr
     
    def forward(self, im_q, im_k):
        # 计算query编码
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)
        
        # 计算key编码
        with torch.no_grad():
            self._momentum_update_key_encoder() # 更新key的参数
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)
        
        # 计算相似度
        # positive logits: Nx1
        s_pos = torch.sum(q*k, dim=1).unsqueeze(dim=1)
        # negative logits: NxK
        s_neg = torch.matmul(q, self.queue.clone().detach().T)
        
        # 拼接相似度 logits: Nx(1+K)
        logits = torch.cat([s_pos, s_neg], dim=1)
        logits /= self.T
        
        # 创建标签
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
        
        # 更新队列
        self._dequeue_and_enqueue(k)

        # 计算InfoNCE损失
        loss = F.cross_entropy(logits, labels)
        
        return loss
    

def test_moco():
    # 初始参数
    batch_size = 4
    feature_dim = 2048
    encoder_dim = 256
    queue_size = 32
    
    # 创建MoCo模型
    model = MoCo(base_encoder=MLP,in_dim=feature_dim, out_dim=encoder_dim, K=queue_size)
    
    # 创建随机输入数据
    im_q = torch.rand(batch_size, feature_dim)
    im_k = torch.rand(batch_size, feature_dim)
    
    # 将数据移到GPU上
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    im_q = im_q.to(device)
    im_k = im_k.to(device)
    
    # 前向传播并计算损失
    loss = model(im_q, im_k)
    
    # 打印结果
    print('Loss:', loss.item())

# 运行测试用例
test_moco()
  • 32
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
对比学习是一种基于相似性学习的方法,它通过比较不同样本之间的相似性来学习特征表示。SimCLR、InfoLossMOCO、BYOL都是最近几年提出的基于对比学习的预训练模型。 SimCLR是一种基于自监督学习的对比学习方法,它采用了一种新的数据增强方法,即随机应用不同的图像变换来生成不同的视图,并通过最大化同一视图下不同裁剪图像的相似性来训练模型。SimCLR在多个视觉任务上均取得了优异的表现。 InfoLoss是另一种基于自监督学习的对比学习方法,它通过最小化同一样本的不同视图之间的信息丢失来学习特征表示。InfoLoss可以通过多种数据增强方法来生成不同的视图,因此具有很强的可扩展性。 MOCO(Momentum Contrast)是一种基于动量更新的对比学习方法,它通过在动量更新的过程中维护一个动量网络来增强模型的表示能力。MOCO在自然语言处理和计算机视觉领域均取得了出色的表现。 BYOL(Bootstrap Your Own Latent)是一种基于自监督学习的对比学习方法,它通过自举机制来学习特征表示。BYOL使用当前网络预测未来的网络表示,并通过最小化预测表示与目标表示之间的距离来训练模型。BYOL在图像分类和目标检测任务上均取得了很好的表现。 总体来说,这些对比学习方法都是基于自监督学习的,它们通过比较不同样本或不同视图之间的相似性来学习特征表示,因此具有很强的可扩展性和泛化能力。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值