对比学习系列(三)-----SimCLR

SimCLR

SimCLR通过隐藏空间的对比损失最大化相同数据在不同增广下的一致性来学习表达。SimCLR框架有四个主要的组件,分别是:数据增广,encode网络,projection head网络和对比学习函数。

在这里插入图片描述
对于数据 x x x,从同一个数据增广族中抽取两个独立的数据增广算子( t ∼ T t \sim T tT t ′ ∼ T {t}' \sim T tT),以获得两个相关的视图 x ^ i \hat{x}_{i} x^i x ^ j \hat{x}_{j} x^j x ^ i \hat{x}_{i} x^i x ^ j \hat{x}_{j} x^j是一对正样本,然后一个神经网络编码器 f ( ⋅ ) f\left( \cdot \right) f()从增广的数据中提取特征 h i = f ( x ^ i ) , h j = f ( x ^ j ) , h_{i}=f\left( \hat{x}_{i} \right), h_{j}=f\left( \hat{x}_{j} \right), hi=f(x^i),hj=f(x^j),。再然后一个小的神经网络project head g ( ⋅ ) g\left( \cdot \right) g()将特征映射到对比损失的空间。project head采用带有一个隐含层的MLP获取 z i = g ( h i ) = W ( 2 ) σ ( W ( 1 ) h i ) z_{i} = g\left( h_{i} \right) = W^{\left( 2 \right)} \sigma \left( W^{\left( 1 \right)} h_{i}\right) zi=g(hi)=W(2)σ(W(1)hi)

对于包含一对正样本 x ^ i \hat{x}_{i} x^i x ^ j \hat{x}_{j} x^j的集合 { x ^ k } \{ \hat{x}_{k} \} {x^k},对比预测任务目的是对于给定的 x ^ i \hat{x}_{i} x^i { x ^ } k ≠ i \{ \hat{x} \}_{k \neq i} {x^}k=i中识别出 x ^ j \hat{x}_{j} x^j。随机挑选 N N N个样本组成一个minibatch,这个minibatch中则有 2 N 2N 2N个数据样本,将其他 2 ( N − 1 ) 2\left( N - 1\right) 2(N1)个扩增的样本作为这个minibatch中的负样本,设 s i m ( u , v ) = u T v / ∥ u ∥ ∥ v ∥ sim\left( u, v\right) = u^{T}v / \| u\| \| v\| sim(u,v)=uTv/∥u∥∥v表示 l 2 l_{2} l2正则化后你的 u u u v v v的点积,那么对一对正样本 ( i , j ) \left( i, j \right) (i,j),损失函数如下定义:

l i , j = − l o g e x p ( s i m ( z i , z j ) / τ ) ∑ k = 1 2 N 1 [ k ≠ i ] e x p ( s i m ( z i , z k ) / τ ) l_{i,j} = - log \frac{exp\left( sim \left( z_{i}, z_{j}\right) / \tau \right)}{\sum_{k=1}^{2N} \mathbb{1}_{[ k \neq i]} exp\left( sim \left( z_{i}, z_{k}\right) / \tau \right)} li,j=logk=12N1[k=i]exp(sim(zi,zk)/τ)exp(sim(zi,zj)/τ)

最后的损失函数计算一个minibatch中的所有的正样本对,包括 ( i , j ) \left( i, j \right) (i,j) ( j , i ) \left( j,i \right) (j,i)。下面是SimCLR的伪代码。从伪代码中可以看出,编码器 f ( ⋅ ) f\left( \cdot \right) f()和project head g ( ⋅ ) g\left( \cdot \right) g() 在训练时都会被更新参数,但是只有编码器 f ( ⋅ ) f\left( \cdot \right) f()用于下游任务。
在这里插入图片描述
simCLR不采用memory bank的形式进行训练,而是加大batchsize,bacth size为8192,对于每一个正样本,将会有16382张负样本实例。增大batch size其实相当于每个minibatch时动态生成一个memory bank。论文中发现使用标准的SGD/Momentum,大batch size训练时是不稳定的,论文中采用LARS优化器。

参考

  1. The Illustrated SimCLR Framework
  2. SimCLR
  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
下面是一个简单的示例,展示了如何使用PyTorch实现simCLR的NT-Xent损失函数: ```python import torch import torch.nn as nn import torch.nn.functional as F class NTXentLoss(nn.Module): def __init__(self, temperature=0.5): super(NTXentLoss, self).__init__() self.temperature = temperature def forward(self, z1, z2): batch_size = z1.size(0) # 计算相似性矩阵 sim_matrix = torch.matmul(z1, z2.t()) / self.temperature # 构造标签 labels = torch.arange(batch_size).to(z1.device) # 计算正样本的损失 pos_loss = F.cross_entropy(sim_matrix, labels) # 计算负样本的损失 neg_loss = F.cross_entropy(sim_matrix.t(), labels) # 总损失为正样本损失和负样本损失之和 loss = pos_loss + neg_loss return loss ``` 在这个代码中,我们定义了一个名为NTXentLoss的自定义损失函数类。它接受两个输入张量z1和z2,这些张量表示两个不同的样本的特征表示。其中,z1和z2的形状应该都是(batch_size, feature_dim)。temperature参数用于缩放相似性矩阵。 在forward方法中,我们首先计算了z1和z2之间的相似性矩阵,然后使用相似性矩阵和标签(labels)计算正样本的损失和负样本的损失。最后,我们将正样本损失和负样本损失相加得到总的损失。 这只是一个简单的示例,实际实现中可能需要进行一些额外的处理和调整,具体取决于实验的要求和模型的结构。 相关问题: - simCLR中的NT-Xent损失函数是如何帮助模型学习到更好的特征表示的? - simCLR中的temperature参数的作用是什么?如何选择合适的值? - 除了NT-Xent损失函数,simCLR还有哪些关键的组成部分? - 在实际应用中,如何使用simCLR训练一个图像特征提取器?

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

马鹤宁

谢谢

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值