【论文翻译】BYOL A New Approach to Self-Supervised Learning

BYOL是一种新的自监督学习方法,不依赖负样本对,通过在线和目标网络交互学习图像表示。在ImageNet上,BYOL在没有负样本对的情况下,使用ResNet-50达到74.3%的top-1分类准确率,优于对比方法。此外,它在半监督和transfer基准上表现与现有技术相当或更好,且对图像增强选择具有更强的鲁棒性。
摘要由CSDN通过智能技术生成

标题:BYOL 一种新的自监督学习方法

摘要

我们介绍BYOL,一种新的自监督图像表示学习的方法。BYOL依赖于两个神经网络,称为在线和目标网络,它们相互作用并相互学习。从图像的增强视图,我们训练在线网络来预测同一图像在不同增强视图下的目标网络表示。同时,我们用在线网络的缓慢移动平均值来更新目标网络。虽然最先进的方法本质上依赖于负样本对,但BYOL在没有负样本对的情况下达到了一个新的水平。使用标准线性评估协议和ResNet-50架构,BYOL在ImageNet上达到74.3%的top-1分类准确率,而使用更大的ResNet时达到79.6%。我们表明,BYOL在 transfer和半监督基准上的表现与目前的技术水平相当或更好。

1 介绍

学习良好的图像表示是计算机视觉的一个关键挑战[1,2,3],因为它允许对下游任务进行有效的训练[4,5,6,7]。已经提出了许多不同的训练方法来学习这种表示,通常依赖于视觉前置任务 pretext task。其中,最先进的对比方法[8,9,10,11,12]是通过减小同一图像的不同增强视图的表示之间的距离(“正样本对”)和增加来自不同图像的增强视图的表示之间的距离(“负样本对”)来训练的。这些方法需要通过依赖大批量[8,12],记忆库[9]或定制挖掘策略[14,15]来检索负对,从而仔细处理负对[13]。此外,它们的性能主要取决于图像增强的选择[8,12]。

在这篇文章中,我们介绍了BYOL,一种新的自监督图像表示学习的方法。在不使用负样本对的情况下,BYOL获得了比最先进的对比方法更高的性能。它反复引导网络的输出作为增强表示的目标。此外,与对比方法相比,BYOL对图像增强的选择具鲁棒性;我们怀疑,不依赖负对是其鲁棒性提高的主要原因之一。虽然以前基于自举的方法使用伪标签或聚类索引[16,17]作为目标,但我们建议直接自举表示。特别是,BYOL使用两个神经网络,称为在线和目标网络,相互作用和相互学习。从图像的增强视图开始,BYOL训练其在线网络来预测同一图像的另一增强视图的目标网络表示。即使这个问题可能遇到以外的结果,例如,对所有图像输出零,我们根据经验表明,使用在线网络的缓慢移动平均值作为目标网络足以避免这种崩溃到无效解。

我们使用ResNet架构评估了BYOL在ImageNet [18]和其他视觉基准上所学的表现[19]。在ImageNet上的线性评估协议下,包括在冻结表示之上训练线性分类器,BYOL在标准ResNet-50上达到74.3%的top-1精度,在更大的ResNet上达到79.6%的top-1精度(图1)。在ImageNet的半监督和 transfer设置中,我们获得了与当前技术水平相当或更好的结果。我们的贡献是:(1)我们引入了BYOL,一种自监督的表示学习方法(第3节),它在不使用负对的情况下,在ImageNet上的线性评估协议下获得最先进的结果。(二)我们表明,在半监督和 transfer基准测试中,我们所学的表现优于最先进的水平(第4节)。(三)我们表明,与对比对手相比,BYOL对批量和图像增强集的变化更有弹性(第5节)。特别是,当仅使用随机裁剪作为图像增强时,BYOL的性能下降比强对比基线SimCLR小得多。

在这里插入图片描述
图1:使用ResNet-50和我们的最佳体系结构ResNet-200 (2X)在ImageNet上的BYOL性能(线性评估),与其他无监督和有监督的(sup.)基线[8]相比较。

2 相关工作

大多数无监督的表示学习方法可以分为生成式或区分式[20,8]。表示学习的生成式方法在数据和潜在嵌入上建立分布,并将学习到的嵌入用作图像表示。这些方法中的许多依赖于图像的自动编码[21,22,23]或对抗学习[24],联合建模数据和表示[25,26,27,28]。生成式方法通常直接在像素空间中操作。然而,这在计算上是昂贵的,并且图像生成所需的高水平细节对于表示学习可能不是必需的。

在鉴别方法中,对比方法[9,10,29,30,31,11,32,33]目前在自监督学习[34,8,12]中取得了最先进的表现。对比方法通过使同一图像的不同视图的表示更接近(“正对”),并将来自不同图像的视图的表示分开(“负对”),避免了像素空间中昂贵的生成步骤[35,36]。对比方法通常需要将每个例子与许多其他例子进行比较才能很好地工作[9,8],这就提出了使用否定对是否必要的问题。DeepCluster [17]部分回答了这个问题。它在其表示的先前版本上使用引导来产生下一个表示的目标;它使用先前表示对数据点进行聚类,并使用每个样本的聚类索引作为新表示的分类目标。在避免使用负对的同时,这需要昂贵的聚类阶段和特定的预防措施,以避免陷入这种无效解。

一些自我监督的方法不是对比的,而是依赖于使用辅助的手工预测任务来学习它们的表示。特别是,相对补片预测[20,36],彩色灰度图像[37,38],图像修补[39],图像拼图[40],图像超分辨率[41],和几何变换[42,43]已被证明是有用的。然而,即使有合适的体系结构[44],这些方法也被对比方法[34,8,12]所超越。

我们的方法与自举延迟预测(PBL[45])有一些相似之处,自举延迟预测是一种用于强化学习的自监督表征学习技术。PBL联合训练代理的历史表现和未来观察的编码。观察编码被用作训练代理的表示的目标,而代理的表示被用作训练观察编码的目标。与PBL不同,BYOL使用其代表的缓慢移动平均值来提供其目标,并且不需要第二个网络。

在自我监督学习中,MoCo [9]使用移动平均网络(动量编码器)来保持从存储库中提取的负对的一致表示。相反,BYOL使用移动平均网络来产生预测目标,作为稳定引导步骤的手段。我们在第5节中表明,这种稳定效应也可以改进现有的对比方法。

3 模型

我们从激励我们的方法开始,然后在第3.1节解释它的细节。许多成功的自监督学习方法建立在[46]中介绍的交叉视图预测框架的基础上。通常,这些方法通过彼此预测同一图像的不同视图(例如,不同的随机裁剪)来学习表示。许多这样的方法将预测问题直接投射到表示空间中:图像的增强视图的表示应该是同一图像的另一增强视图的表示的预测。然而,直接在表示空间中进行预测会导致表示的折叠:例如,一个表示在视图之间是恒定的,它总是对自己有完全的预测。对比方法通过将预测问题重新表述为一个辨别问题来回避这个问题:从一个增强视图的表示中,他们学会辨别同一图像的另一个增强视图的表示和不同图像的增强视图的表示。在绝大多数情况下,这阻止了训练找到折叠的表示。然而,这种辨别方法通常需要将增强视图的每个表示与许多负面示例进行比较,以找到足够接近的示例,从而使辨别任务具有挑战性。在这项工作中,我们的任务是找出这些负面的例子对于防止崩溃和保持高性能是否不可或缺。

为了防止崩溃,一个简单的解决方案是使用一个固定的随机初始化的网络来产生我们预测的目标。虽然避免了崩溃,但从经验上看,它不会产生很好的表现。尽管如此,有趣的是,使用该过程获得的表示已经比初始固定表示好得多。在我们的消融研究(第5节)中,我们通过预测一个固定的随机初始化网络来应用该程序,并在ImageNet上的线性评估协议中实现了18.8%的top-1准确性(表5a),而随机初始化网络本身仅实现了1.4%。这个实验性的发现是BYOL的核心动机:从一个给定的表示(称为目标),我们可以通过预测目标表示来训练一个新的、潜在增强的表示(称为在线)。从那里,我们可以期望通过迭代这个过程来建立一个质量不断提高的表示序列,使用后续的在线网络作为进一步训练的新目标网络。在实践中,BYOL通过迭代地改进其表示来推广这种自举过程,但是使用在线网络的缓慢移动的指数平均值作为目标网络,而不是固定的检查点。(对最后一句话的注释 “目标网络已在[47]中介绍,并在深度强化学习中普遍使用[48,49,50]。在深度反向链路中,目标网络稳定由贝尔曼方程提供的引导更新,使得它们对稳定BYOL的引导机制有吸引力。虽然固定目标网络在深层R1中更常见,但BYOL使用了以前网络的加权移动平均,类似于[51],以便在目标表示中提供更平滑的变化” 。)

3.1 BYOL的描述
在这里插入图片描述
图2: BYOL的框架。BYOL最小化了qθ(z)和sg(z)之间的相似性损失,其中θ是训练的权重,ξ是θ的指数移动平均值,sg表示停止梯度。在训练结束时,除fθ之外的一切都被丢弃,y用作图像表示。

BYOL的目标是学习一种表示法y,然后用于下游任务。如前所述,BYOL使用两种神经网络进行学习:在线网络和目标网络。在线网络由一组权重θ定义,由三个阶段组成:编码器fθ、投影器gθ和预测器qθ,如图2所示。

目标网络具有与在线网络相同的体系结构,但使用不同的权重集ξ。目标网络提供回归目标来训练在线网络,其参数ξ是在线参数θ [51]的指数移动平均。更准确地说,给定一个目标衰减率τ ∈ [0,1],在每个训练步骤之后,我们执行以下更新,
在这里插入图片描述

给定一组图像D,从D中均匀采样的图像X~ D以及两个图像增强分布T和T0,BYOL通过分别应用图像增强t~ T和t0~T0从X产生两个增强视

V =∏T(x)
V0 =∏T0(x)。从第一个增强视图v,在线网络输出表示y =∏fθ(v)
和投影zθ=∏gθ(y)
。目标网络从第二增强视图v0输出
和目标投影z0ξ=∏gξ(y0)
。然后,我们输出z0ξ的预测值qθ(zθ),并将qθ(zθ)和z0ξ l2正则化为在这里插入图片描述。最后,我们定义了正则化预测和目标投影之间的以下均方误差(对这句话的注释:“虽然我们可以直接预测表示y,而不是投影z,以前的工作[8]已经根据经验表明,使用这种投影可以提高性能。” ):

在这里插入图片描述

我们将等式2中的损失LBYOL
对称化,分别将v0馈送到在线网络和v馈送到目标网络以计算LBYOL
。在每个训练步骤中,我们执行一个随机优化步骤,使LBYOL+LBYOL
相对于θ最小,而不是ξ,如图2中的停止梯度所示。

在训练结束时,我们只保留编码器fθ;如[9]所示。与其他方法相比,我们仅在最终表示fθ中考虑推理时间权重的数量。附录A总结了完整的培训过程,附录G提供了基于JAX [52]和Haiku[53]库的python伪代码。

3.2 实现细节

图像增强 BYOL使用与SimCLR [8]中相同的图像增强集。首先,通过随机水平翻转选择图像的随机块并将其大小调整到224 × 224,随后是颜色失真,由亮度、对比度、饱和度、色调调整和可选灰度转换的随机序列组成。最后,高斯模糊和solarization被用于随机块。关于图像增强的更多细节见附录B。

框架我们使用具有50层和后激活(ResNet-50(1×) v1)的卷积残差网络[19]作为我们的基本参数编码器fθ和fξ。我们还使用更深(50、101、152和200层)和更宽(从1倍到4倍)的ResNets,如[54、44、8]所示。具体而言,表示y对应于最终平均汇集图层的输出,其特征尺寸为2048(宽度乘数为1x)。如在SimCLR [8]中,表示y被多层感知器(MLP) gθ投影到更小的空间,并且对于目标投影gξ也是相同。该MLP包括输出尺寸为4096的线性层,随后是批量归一化[55]、校正线性单元(ReLU) [56],以及输出尺寸为256的最终线性层。与SimCLR相反,这个MLP的输出不是批量标准化的。预测器qθ采用与gθ相同的架构。

优化我们使用LARS优化器 [ 57 ] 和余弦衰减学习速率表 [ 58 ] ,没有重新启动,超过1000个epoch,预热周期为10个epoch。我们将基本学习率设置为0.2,随批次大小线性缩放[ 59 ] (学习率= 0.2 ×批次大小/256)。此外,我们使用1.5106
的全局权重衰减参数,同时从LARS自适应和权重衰减中排除偏差和批次标准化参数。对于目标网络,指数移动平均参数τ从τbase= 0.996开始,在训练期间增加到1。具体来说,我们设置000002
,k为当前训练步骤,K为最大训练步骤数。我们使用的批量大小为4096,分布在512个云TPU v3内核上。使用这种设置,ResNet-50(×1)的训练大约需要8个小时。附录g总结了所有的超参数。

4 实验评估

在对ImageNet ILSVRC-2012数据集的训练集进行自监督预处理后,我们评估了BYOL表示法的性能[18]。我们首先在 ImageNet (IN)上使用线性评估和半监督设置中对其进行评估。然后,我们在其他数据集和任务上测量其 transfer 能力,包括分类、分割、对象检测和深度估计。为了进行比较,我们还报告了使用来自训练 ImageNet子集 (称为监督输入)的标签训练的表示的效果。在附录E中,我们通过在再现该评估方案之前在place 365-标准数据集[60]上预处理一个表示来评估BYOL的一般性。

ImageNet 的线性评价我们首先按照[44,61,37,10,8]和附录C.1中描述的程序,通过在冻结的表示上训练线性分类器来评估BYOL表示;我们报告了表1中测试集的top-1和top-5准确度,单位为%。使用标准ResNet-50 (×1),BYOL获得了74.3%的top-1准确度(91.6%的top-5准确度),与之前的自监督技术相比提高了1.3%(分别为0.5%)[12]。这缩小了与监督baseline[8]76.5%的差距,但仍明显低于更强的监督baseline[62]78.9%。凭借更深、更宽的体系结构,BYOL始终超越了之前的技术水平(附录C.2),并获得了79.6%的top-1准确率的最佳性能,排名高于之前的自我监督方法。在ResNet-50 (4x)上,BYOL达到了78.6%,类似于[8]中相同体系结构的最佳监督baseline的78.9%。
在这里插入图片描述
ImageNet上的半监督训练接下来,我们评估在使用ImageNet训练集的一个小子集对分类任务上的BYOL表示进行微调时获得的性能,这次使用的是带标签信息。我们遵循附录C.1中详述的[61,63,8,29]的半监督协议,并使用与[8]中相同的分别为1%和10%的ImageNet标记的训练数据的固定分割。我们在表2的测试集上报告了top-1和top-5的准确度。在广泛的体系结构中,BYOL始终优于以前的方法。此外,如附录C.1所述,当微调超过100%的ImageNet标签时,使用ResNet-50,BYOL达到了77.7%的top-1准确度。
在这里插入图片描述

转移到其他分类任务我们评估我们在其他分类数据集上的表现,以评估在 ImageNet (IN)上学习的特征是否是通用的,从而在图像域中是有用的,或者它们是否是特定于 ImageNet (IN)的。我们对[8,61]中使用的同一组分类任务进行线性评估和微调,并仔细遵循其评估方案,详见附录d。使用每个基准的标准度量来报告性能,并且在验证集的超参数选择之后,在保留的测试集上提供结果。我们在表3中报告了线性评估和微调的结果。在所有基准测试中,BYOL的表现都优于SimCLR,在12个基准测试中,有7个基准测试的表现优于监督输入基准测试,其余5个基准测试的表现仅略差。BYOL的表现可以转移到小图像,例如,CIFAR [65],风景,例如,SUN397 [66]或VOC2007 [67],和纹理,例如,DTD [68]。
在这里插入图片描述

转移到其他视觉任务我们评估我们在与计算机视觉从业者相关的不同任务上的表现,即语义分割、目标检测和深度估计。通过这一评估,我们评估了BYOL的表现是否超越了分类任务。

我们首先根据附录D.4中详述的2012年词汇语义分割任务对BYOL进行评估,目标是对图像中的每个像素进行分类[7]。我们在表4a中报告了结果。BYOL的表现优于监督输入Supervised-IN baseline(+1.9 mIoU)和SimCLR(+1.1 mIoU)。

类似地,我们通过使用Faster R-CNN架构[69]复制[9]中的设置来评估目标检测,详见附录D.5。我们对trainval2007进行了微调,并使用标准AP50指标报告了test2007的结果;BYOL明显好于Supervised-IN baseline (+3.1 AP50) and SimCLR (+2.3 AP50)。

最后,我们对NYU v2数据集的深度估计进行评估,其中场景的深度图是在给定单个RGB图像的情况下估计的。深度预测衡量的是一个网络在多大程度上代表了几何图形,以及该信息在多大程度上可以定位到像素精度[36]。该设置基于[70],详见附录D.6。我们对654幅图像的常用测试子集进行评估,并使用表4b中的几种常用指标报告结果:相对(rel)误差、均方根(rms)误差和像素百分比(pct),其中最大误差(dgt/dp,dp/dgt)低于1.25n阈值,其中dps为预测深度,dgt为地面真实深度[36]。在每一项指标上,BYOL都优于或等同于其他方法。例如,具有挑战性的pct.< 1.25的测量值与监督基线和SimCLR基线相比分别提高了+3.5和+1.3个点。
在这里插入图片描述

5 用消融建立直觉

我们在BYOL上展示消融术,以给出其行为和表现的直觉。为了再现性,我们在三个seeds上运行每个参数配置,并报告平均性能。当它大于0.25时,我们还报告了最佳和最差运行之间的一半差异。尽管以前的研究在100个epoch进行消融[8,12],我们注意到,100个epoch的相对进步并不总是能持续更长时间的训练。因此,我们在64个TPUv3核心上进行了超过300个epoch的消融,与1000个epoch的baseline训练相比,产生了一致的结果。对于本节中的所有实验,我们将初始学习速率设置为0.3,batch为4096,权重衰减为10的负6次方,与SimCLR [8]一致,基本目标衰减速率τbase设置为0.99。在本节中,我们报告了在附录C.1中的线性评估协议下,ImageNet的top-1准确度的结果。

Batch size在对比方法中,从minibatch中提取负面示例的方法在批量减少时性能会下降。BYOL不使用负面的例子,我们希望它对minibatch更具鲁棒性。为了实证地验证这一假设,我们使用从128到4096的不同批处理大小训练BYOL和SimCLR。为了避免重新调整其他超参数,我们在更新在线网络之前将梯度平均为连续N步,同时将批处理大小减少N个因子。在线网络更新后,每N步更新一次目标网络;我们在运行过程中并行地累积N个步骤。

如图3a所示,SimCLR的性能随着批次大小而迅速恶化,这可能是由于负面示例的数量减少。相比之下,BYOL的性能在从256到4096的大批量范围内保持稳定,并且由于编码器中的批量标准化层,只有较小的值会下降。(对最后这句话的注释:“在我们的训练过程中,唯一依赖于批处理大小的是批处理规范化层。”)
在这里插入图片描述

图像增强对比方法对图像增强的选择很敏感。例如,当从图像增强中消除颜色失真时,SimCLR不能很好地工作。作为一种解释,SimCLR表明,同一幅图像的裁剪大多共享它们的颜色直方图。同时,不同图像的颜色直方图也不同。因此,当对比任务仅依赖于随机裁剪作为图像增强时,它可以通过只关注颜色直方图来解决。结果,该表示没有被激励来保留颜色直方图之外的信息。为了防止这种情况,SimCLR在其图像增强设置中增加了颜色失真。取而代之的是,BYOL被激励将目标表示捕捉到的任何信息保留在其在线网络中,以改进其预测。因此,即使相同图像的增强视图共享相同的颜色直方图,BYOL仍然被激励在其表示中保留附加特征。因此,我们认为BYOL对图像增强的选择比对比方法更为有力。

图3b中的结果支持这一假设:当从图像增强集中消除颜色失真时,BYOL的性能比SimCLR的性能受影响小得多(BYOL为9.1个精度点,SimCLR为22.2个精度点)。当图像增强被简化为仅仅是随机裁剪时,BYOL仍然表现良好(59.4%,即从72.5%下降了13.1%),而SimCLR损失了三分之一以上的性能(40.3%,即从67.9%下降了27.6%)。我们在附录F.3中报告了附加消融。

自展 BYOL使用目标网络的投影表示,其权值是在线网络权值的指数移动平均值,作为其预测的目标。这样,目标网络的权值就代表了在线网络权值的一个延迟和更稳定的版本。当目标衰减率为1时,目标网络不进行更新,保持初始化时对应的常数值。当目标衰减率为0时,每一步目标网络都立即更新为在线网络。在更新目标过于频繁和太慢之间存在权衡,如表5a所示。瞬时更新目标网络(τ = 0)会破坏训练的稳定性,导致训练的性能非常差;而不更新目标网络(τ = 1)会使训练稳定,但会阻碍迭代改进,最终得到质量较低的最终表示。衰减率在0.9和0.999之间的所有值在300个epoch产生的性能高于68.4%的top-1精度。
在这里插入图片描述

对比方法的消融在这一小节中,我们使用相同的形式体系重新描述了SimCLR和BYOL,以更好地理解BYOL相对于SimCLR的改进来自哪里。让我们考虑以下扩展了InfoNCE目标[10]的目标:
在这里插入图片描述

其中α > 0是固定温度参数,β ∈ [0,1]是加权系数,B是批量大小,v和v0是批量的增强视图,其中对于任何批量索引 ii,vi和vi0是来自同一图像的增强视图;实值函数Sθ量化了增强视图之间的成对相似性。对于任何增广视图u,我们表示zθ(u),fθ(gθ(u))和zξ(u),fξ(gξ(u))
。对于给定的φ和ψ,我们考虑归一化的点积
在这里插入图片描述
对于次要细节(参见附录F.5),我们利用φ(u1) = zθ(u1)(无预测器)、ψ(u2) = zθ(u2)(无目标网络)和β = 1来恢复SimCLR损耗。当使用预测器和目标网络,即φ(u1) = pθ(zθ(u1))和ψ(u2) = zξ(u2)且β = 0时,我们恢复BYOL损耗。为了评估目标网络、预测器和系数β的影响,我们对它们执行消融。结果见表5b,更多细节见附录F.4。

唯一一个在没有负样本(即β = 0)的情况下表现良好的变体是BYOL,它使用自举目标网络和预测器。在不重新调整温度参数的情况下,将负对加到BYOL损耗上会损害其性能。在附录F.4中,我们展示了我们可以将负对加回来,并且仍然通过适当的温度调整来匹配BYOL的性能。

简单地向SimCLR添加一个目标网络已经提高了性能(+1.6个百分点)。这为在MoCo使用目标网络提供了新的线索[9],在那里目标网络被用来提供更多负样本。在这里,我们表明,仅仅通过稳定效应,即使使用相同数量的负样本,使用目标网络是有益的。最后,我们观察到修改sθ的架构以包含一个预测器只会轻微地影响SimCLR的性能。

6 结论

我们介绍了BYOL,一种新的自监督图像表示学习算法。BYOL通过预测其输出的先前版本来学习它的表现,而不使用负对。我们表明,BYOL在各种基准上取得了最先进的成果。特别是,在带有ResNet-50 (1×)的ImageNet上的线性评估协议下,BYOL实现了新的技术水平,并弥合了自我监督方法和监督学习baseline之间的大部分剩余差距[8]。使用ResNet-200 (2X),BYOL达到了79.6%的top-1精度,这比以前的技术水平(76.8%)有所提高,同时使用的参数减少了30%。

然而,BYOL仍然依赖于现有的视觉应用专用的增强设备。为了将BYOL推广到其他模态(例如,音频、视频、文本、.。。)有必要为它们中的每一个获得类似的合适的增强。设计这样的增强可能需要大量的努力和专业知识。因此,自动搜索这些增强将是将BYOL推广到其他模式的重要的下一步。

更广泛的影响
提出的研究应该被归类为无监督学习领域的研究。这项工作可能会启发新的算法,理论和实验研究。这里提出的算法可以用于许多不同的视觉应用,特定的应用可能有正面或负面的影响,这就是所谓的双重用途问题。此外,由于视觉数据集可能有偏差,BYOL所学的表征可能容易重复这些偏差。

以下是使用 BYOL(Bootstrap Your Own Latent)算法训练 CIFAR-10 数据集并绘制 t-SNE 图的示例代码: 首先,确保已安装必要的库,如 pytorch、torchvision、numpy 和 sklearn。然后,按照以下步骤进行操作: ```python import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms import numpy as np from sklearn.manifold import TSNE import matplotlib.pyplot as plt # 设置随机种子以确保实验的可复现性 torch.manual_seed(0) np.random.seed(0) # 加载 CIFAR-10 数据集 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True) # 定义 BYOL 网络模型(示例) class BYOLNet(nn.Module): def __init__(self): super(BYOLNet, self).__init__() # 定义网络结构,这里仅作示例,你可以根据需要自定义网络结构 self.encoder = nn.Sequential( nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Flatten(), nn.Linear(32 * 8 * 8, 128), nn.ReLU() ) def forward(self, x): return self.encoder(x) # 定义 BYOL 训练函数 def train_byol(model, dataloader, optimizer, device): model.train() for data, _ in dataloader: data = data.to(device) optimizer.zero_grad() output = model(data) loss = torch.mean(output) # 示例损失函数,你可以根据需要修改 loss.backward() optimizer.step() # 创建 BYOL 模型实例 model = BYOLNet() # 定义优化器和设备 optimizer = optim.Adam(model.parameters(), lr=0.001) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 将模型移动到设备上 model.to(device) # BYOL 训练循环 num_epochs = 10 for epoch in range(num_epochs): train_byol(model, train_loader, optimizer, device) # 获取训练集的特征向量 features = [] labels = [] for data, target in train_loader: data = data.to(device) output = model.encoder(data).detach().cpu().numpy() features.extend(output) labels.extend(target.numpy()) # 使用 t-SNE 进行降维 tsne = TSNE(n_components=2) features_tsne = tsne.fit_transform(features) # 绘制 t-SNE 图 plt.scatter(features_tsne[:, 0], features_tsne[:, 1], c=labels, cmap='tab10') plt.colorbar() plt.show() ``` 这段代码会训练 BYOL 模型使用 CIFAR-10 数据集,并使用 t-SNE 算法将训练集的特征向量降维为二维,并将其可视化在散点图上。你可以根据需要自定义 BYOL 网络模型、损失函数、优化器等。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值