SimCLR图像分类——pytorch复现

该博客介绍了SimCLR框架在图像分类任务中的实现过程,包括无监督学习阶段和有监督学习阶段。无监督阶段使用ResNet50提取特征,去除池化和全连接层,通过全连接、批次标准化和ReLU激活得到输出特征。有监督阶段则加载无监督阶段的特征提取层,仅训练全连接层。整个过程涉及随机图像增强、对比学习损失函数和交叉熵损失函数。代码展示了网络模型、训练流程和损失函数计算。此外,还提供了配置文件、数据加载、训练过程的可视化以及验证集评估和自定义图片测试的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、网络模型、损失函数

1.原理

SimCLR(A Simple Framework for Contrastive Learning of Visual Representations)是一种对比学习网络,可以对含有少量标签的数据集进行训练推理,它包含无监督学习和有监督学习两个部分。

无监督学习网络特征提取采用resnet50,将输入层进行更改,并去掉池化层及全连接层。之后将特征图平坦化,并依次进行全连接、批次标准化、relu激活、全连接,得到输出特征。

有监督学习网络使用无监督学习网络的特征提取层及参数,之后由一个全连接层得到分类输出。

在第一阶段先进行无监督学习,对输入图像进行两次随机图像增强,即由一幅图像得到两个随机处理过后的图像,依次放入网络进行训练,计算损失并更新梯度。
在这里插入图片描述
这一阶段损失函数为:
在这里插入图片描述
其中,x+为与x相似的样本,x-为与x不相似的样本。

第二阶段,加载第一阶段的特征提取层训练参数,用少量带标签样本进行有监督学习(只训练全连接层)。这一阶段损失函数为交叉熵损失函数CrossEntropyLoss。

2.code

# net.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet50


# stage one ,unsupervised learning
class SimCLRStage1(nn.Module):
    def __init__(self, feature_dim=128):
        super(SimCLRStage1, self).__init__()

        self.f = []
        for name, module in resnet50().named_children():
            if name == 'conv1':
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
                self.f.append(module)
        # encoder
        self.f = nn.Sequential(*self.f)
        # projection head
        self.g = nn.Sequential(nn.Linear(2048, 512, bias=False),
                               nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True),
                               nn.Linear(512, feature_dim, bias=True))

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.g(feature)
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)


# stage two ,supervised learning
class SimCLRStage2(torch.nn.Module):
    def __init__(self, num_class):
        super(SimCLRStage2, self).__init__()
        # encoder
        self.f = SimCLRStage1().f
        # classifier
        self.fc = nn.Linear(2048, num_class, bias=True)

        for param in self.f.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out


class Loss(torch.nn.Module):
    def __init__(self):
        super(Loss,self).__init__()

    def forward(self,out_1,out_2,batch_size,temperature=0.5):
        # 分母 :X.X.T,再去掉对角线值,分析结果一行,可以看成它与除了这行外的其他行都进行了点积运算(包括out_1和out_2),
        # 而每一行为一个batch的一个取值,即一个输入图像的特征表示,
        # 因此,X.X.T,再去掉对角线值表示,每个输入图像的特征与其所有输出特征(包括out_1和out_2)的点积,用点积来衡量相似性
        # 加上exp操作,该操作实际计算了分母
        # [2*B, D]
        out = torch.cat([out_1, out_2], dim=0)
        # [2*B, 2*B]
        sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
        mask = (torch.ones_like(sim_matrix) - torch.eye(2 * batch_size, device=sim_matrix.device)).bool()
        # [2*B, 2*B-1]
        sim_matrix = sim_matrix.masked_select(mask).view(2 * batch_size, -1)

        # 分子: *为对应位置相乘,也是点积
        # compute loss
        pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
        # [2*B]
        pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
        return (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()


if __name__=="__main__":
    for name, module in resnet50().named_children():
        print(name,module)

二、配置文件

公共参数写入配置文件

# config.py
import os
from torchvision import transforms

use_gpu=True
gpu_name=1

pre_model=os.path.join('pth','model.pth')

save_path="pth"

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

三、无监督学习数据加载

使用CIFAR-10数据集,一共包含10个类别的RGB彩色图片:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿(deer)、狗(dog)、蛙类(frog)、马(horse)、船(ship)和卡车(truck)。图片的尺寸为32×32,数据集中一共有50000张训练图片片和10000张测试图片。

# loaddataset.py
from torchvision.datasets import CIFAR10
from PIL import Image


class PreDataset(CIFAR10):
    def __getitem__(self, item):
        img,target=self.data[item],self.targets[item]
        img = Image.fromarray(img)

        if self.transform is not None:
            imgL = self.transform(img)
            imgR = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return imgL, imgR, target


if __name__=="__main__":

    import config
    train_data = PreDataset(root='dataset', train=True, transform=config.train_transform, download=True)
    print(train_data[0])

四、无监督训练

# trainstage1.py
import torch,argparse,os
import net,config,loaddataset


# train stage one
def train(args):
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        # 每次训练计算图改动较小使用,在开始前选取较优的基础算法(比如选择一种当前高效的卷积算法)
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current deveice:", DEVICE)

    train_dataset=loaddataset.PreDataset(root='dataset', train=True, transform=config.train_transform, download=True)
    train_data=torch.utils.data.DataLoader(train_dataset,batch_size=args.batch_size, shuffle=True, num_workers=16 , drop_last=True)

    model =net.SimCLRStage1().to(DEVICE)
    lossLR=net.Loss().to(DEVICE)
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

    os.makedirs(config.save_path, exist_ok=True)
    for epoch in range(1,args.max_epoch+1):
        model.train()
        total_loss = 0
        for batch,(imgL,imgR,labels) in enumerate(train_data):
            imgL,imgR,labels=imgL.to(DEVICE),imgR.to(DEVICE),labels.to(DEVICE)

            _, pre_L=model(imgL)
            _, pre_R=model(imgR)

            loss=lossLR(pre_L,pre_R,args.batch_size)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print("epoch", epoch, "batch", batch, "loss:", loss.detach().item())
            total_loss += loss.detach().item()

        print("epoch loss:",total_loss/len(train_dataset)*args.batch_size)

        with open(os.path.join(config.save_path, "stage1_loss.txt"), "a") as f:
            f.write(str(total_loss/len(train_dataset)*args.batch_size) + " ")

        if epoch % 5==0:
            torch.save(model.state_dict(), os.path.join(config.save_path, 'model_stage1_epoch' + str(epoch) + '.pth'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--batch_size', default=200, type=int, help='')
    parser.add_argument('--max_epoch', default=1000, type=int, help='')

    args = parser.parse_args()
    train(args)

五、有监督训练

# trainstage2.py
import torch,argparse,os
import net,config
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader


# train stage two
def train(args):
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(2))   #config.gpu_name
        # 每次训练计算图改动较小使用,在开始前选取较优的基础算法(比如选择一种当前高效的卷积算法)
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current deveice:", DEVICE)

    # load dataset for train and eval
    train_dataset = CIFAR10(root='dataset', train=True, transform=config.train_transform, download=True)
    train_data = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16, pin_memory=True)
    eval_dataset = CIFAR10(root='dataset', train=False, transform=config.test_transform, download=True)
    eval_data = DataLoader(eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)

    model =net.SimCLRStage2(num_class=len(train_dataset.classes)).to(DEVICE)
    model.load_state_dict(torch.load(args.pre_model, map_location='cpu'),strict=False)
    loss_criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)

    os.makedirs(config.save_path, exist_ok=True)
    for epoch in range(1,args.max_epoch+1):
        model.train()
        total_loss=0
        for batch, (data, target) in enumerate(train_data):
            data, target = data.to(DEVICE), target.to(DEVICE)
            pred = model(data)

            loss = loss_criterion(pred, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print("epoch",epoch,"loss:", total_loss / len(train_dataset)*args.batch_size)
        with open(os.path.join(config.save_path, "stage2_loss.txt"), "a") as f:
            f.write(str(total_loss / len(train_dataset)*args.batch_size) + " ")

        if epoch % 5==0:
            torch.save(model.state_dict(), os.path.join(config.save_path, 'model_stage2_epoch' + str(epoch) + '.pth'))

            model.eval()
            with torch.no_grad():
                print("batch", " " * 1, "top1 acc", " " * 1, "top5 acc")
                total_loss, total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0, 0
                for batch, (data, target) in enumerate(train_data):
                    data, target = data.to(DEVICE), target.to(DEVICE)
                    pred = model(data)

                    total_num += data.size(0)
                    prediction = torch.argsort(pred, dim=-1, descending=True)
                    top1_acc = torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                    top5_acc = torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
                    total_correct_1 += top1_acc
                    total_correct_5 += top5_acc

                    print("  {:02}  ".format(batch + 1), " {:02.3f}%  ".format(top1_acc / data.size(0) * 100),
                          "{:02.3f}%  ".format(top5_acc / data.size(0) * 100))

                print("all eval dataset:", "top1 acc: {:02.3f}%".format(total_correct_1 / total_num * 100),
                          "top5 acc:{:02.3f}%".format(total_correct_5 / total_num * 100))
                with open(os.path.join(config.save_path, "stage2_top1_acc.txt"), "a") as f:
                    f.write(str(total_correct_1 / total_num * 100) + " ")
                with open(os.path.join(config.save_path, "stage2_top5_acc.txt"), "a") as f:
                    f.write(str(total_correct_5 / total_num * 100) + " ")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--batch_size', default=200, type=int, help='')
    parser.add_argument('--max_epoch', default=200, type=int, help='')
    parser.add_argument('--pre_model', default=config.pre_model, type=str, help='')

    args = parser.parse_args()
    train(args)

六、训练并查看过程

使用visdom,对训练过程保存的loss、acc进行可视化
由于时间关系,只训练了较少的epoch

# showbyvisdom.py
import numpy as np
import visdom


def show_loss(path, name, step=1):
    with open(path, "r") as f:
        data = f.read()
    data = data.split(" ")[:-1]
    x = np.linspace(1, len(data) + 1, len(data)) * step
    y = []
    for i in range(len(data)):
        y.append(float(data[i]))

    vis = visdom.Visdom(env='loss')
    vis.line(X=x, Y=y, win=name, opts={'title': name, "xlabel": "epoch", "ylabel": name})


def compare2(path_1, path_2, title="xxx", legends=["a", "b"], x="epoch", step=20):
    with open(path_1, "r") as f:
        data_1 = f.read()
    data_1 = data_1.split(" ")[:-1]

    with open(path_2, "r") as f:
        data_2 = f.read()
    data_2 = data_2.split(" ")[:-1]

    x = np.linspace(1, len(data_1) + 1, len(data_1)) * step
    y = []
    for i in range(len(data_1)):
        y.append([float(data_1[i]), float(data_2[i])])

    vis = visdom.Visdom(env='loss')
    vis.line(X=x, Y=y, win="compare",
             opts={"title": "compare " + title, "legend": legends, "xlabel": "epoch", "ylabel": title})


if __name__ == "__main__":
    show_loss("stage1_loss.txt", "loss1")
    show_loss("stage2_loss.txt", "loss2")
    show_loss("stage2_top1_acc.txt", "acc1")
    show_loss("stage2_top5_acc.txt", "acc1")

    # compare2("precision1.txt", "precision2.txt", title="precision", step=20)

无监督学习损失变化曲线:
在这里插入图片描述
有监督学习损失变化曲线
在这里插入图片描述

七、验证集评估

# eval.py
import torch,argparse
from torchvision.datasets import CIFAR10
import net,config


def eval(args):
    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")

    eval_dataset=CIFAR10(root='dataset', train=False, transform=config.test_transform, download=True)
    eval_data=torch.utils.data.DataLoader(eval_dataset,batch_size=args.batch_size, shuffle=False, num_workers=16, )

    model=net.SimCLRStage2(num_class=len(eval_dataset.classes)).to(DEVICE)
    model.load_state_dict(torch.load(config.pre_model, map_location='cpu'), strict=False)

    # total_correct_1, total_correct_5, total_num, data_bar = 0.0, 0.0, 0.0, 0, tqdm(eval_data)
    total_correct_1, total_correct_5, total_num = 0.0, 0.0, 0.0

    model.eval()
    with torch.no_grad():
        print("batch", " "*1, "top1 acc", " "*1,"top5 acc" )
        for batch, (data, target) in enumerate(eval_data):
            data, target = data.to(DEVICE) ,target.to(DEVICE)
            pred=model(data)

            total_num += data.size(0)
            prediction = torch.argsort(pred, dim=-1, descending=True)
            top1_acc = torch.sum((prediction[:, 0:1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            top5_acc = torch.sum((prediction[:, 0:5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            total_correct_1 += top1_acc
            total_correct_5 += top5_acc

            print("  {:02}  ".format(batch+1)," {:02.3f}%  ".format(top1_acc / data.size(0) * 100),"{:02.3f}%  ".format(top5_acc / data.size(0) * 100))

        print("all eval dataset:","top1 acc: {:02.3f}%".format(total_correct_1 / total_num * 100), "top5 acc:{:02.3f}%".format(total_correct_5 / total_num * 100))



if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test SimCLR')
    parser.add_argument('--batch_size', default=512, type=int, help='')

    args = parser.parse_args()
    eval(args)

在这里插入图片描述

八、自定义图片测试

# test.py
import torch,argparse
import net,config
from torchvision.datasets import CIFAR10
import cv2


def show_CIFAR10(index):
    eval_dataset=CIFAR10(root='dataset', train=False, download=False)
    print(eval_dataset.__len__())
    print(eval_dataset.class_to_idx,eval_dataset.classes)
    img, target=eval_dataset[index][0], eval_dataset[index][1]

    import matplotlib.pyplot as plt
    plt.figure(str(target))
    plt.imshow(img)
    plt.show()


def test(args):
    classes={'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
    index2class=[x  for x in classes.keys()]
    print("calss:",index2class)

    if torch.cuda.is_available() and config.use_gpu:
        DEVICE = torch.device("cuda:" + str(config.gpu_name))
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")

    transform = config.test_transform

    ori_img=cv2.imread(args.img_path,1)
    img=cv2.resize(ori_img,(32,32)) # evry important,influence the result

    img=transform(img).unsqueeze(dim=0).to(DEVICE)

    model=net.SimCLRStage2(num_class=10).to(DEVICE)
    model.load_state_dict(torch.load(args.pre_model, map_location='cpu'), strict=False)

    pred = model(img)

    prediction = torch.argsort(pred, dim=-1, descending=True)

    label=index2class[prediction[:, 0:1].item()]
    cv2.putText(ori_img,"this is "+label,(30,30),cv2.FONT_HERSHEY_DUPLEX,1, (0,255,0), 1)
    cv2.imshow(label,ori_img)
    cv2.waitKey(0)


if __name__ == '__main__':
    # show_CIFAR10(2)

    parser = argparse.ArgumentParser(description='test SimCLR')
    parser.add_argument('--pre_model', default=config.pre_model, type=str, help='')
    parser.add_argument('--img_path', default="bird.jpg", type=str, help='')

    args = parser.parse_args()
    test(args)

输入图片:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
输出:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

<think>嗯,用户想复现SimCLR模型,用PyTorch或TensorFlow实现,需要教程和代码。首先,我得回忆一下SimCLR的基本结构。SimCLR是基于对比学习的自监督方法,主要步骤包括数据增强、编码器、投影头、对比损失这些部分。用户提到PyTorch和TensorFlow,可能需要分别给出两种框架的实现方法? 先看用户提供的引用资料。引用[1]和[2]提到了SimCLRv2的PyTorch实现和安装步骤,引用[4]是TensorFlow的代码示例。用户可能希望看到两种框架的代码示例,或者其中一种。但用户的问题里同时提到了PyTorch和TensorFlow,所以可能需要分别回答。 首先,SimCLR的关键点包括: 1. 数据增强:随机裁剪、颜色抖动、灰度化等,生成正样本对。 2. 编码器网络:通常是ResNet,提取特征。 3. 投影头:将编码后的特征映射到对比损失空间,通常是一个MLP。 4. NT-Xent损失函数:归一化温度标度的交叉熵损失。 接下来,针对PyTorch和TensorFlow的实现,需要分别给出步骤。例如,在PyTorch中,可能需要定义数据增强模块,构建编码器(如ResNet),投影头,以及实现对比损失。对于数据加载部分,可能需要自定义Dataset类,应用不同的增强方法。 在引用中,用户提供了PyTorch的安装和项目地址,还有TensorFlow的代码片段。例如,引用[4]中的编码器结构是使用Keras的Sequential模型,包含多个卷积层。这可能是一个简化的版本,实际SimCLR的编码器通常是更深的网络,比如ResNet-50。 用户可能需要更详细的步骤,比如如何构建数据增强、如何定义损失函数、训练循环等。此外,半监督学习的部分可能需要用预训练的编码器进行微调,但用户的问题主要是复现SimCLR模型本身,所以可能集中在预训练阶段。 然后,需要注意引用资料中的信息。例如,引用[1]的PyTorch实现项目可能包含预训练模型,引用[2]提到安装tensorflow-gpu和下载预训练模型。引用[3]介绍PyTorch的特点,可能在PyTorch实现时需要用到相关模块,比如torch.nn、数据处理等。 在回答时,可能需要分步骤,先数据准备,再模型构建,然后损失函数,最后训练过程。每个步骤给出代码片段,并指出引用的来源。例如,PyTorch部分可以参考引用[1]的项目,TensorFlow部分可以参考引用[4]的代码结构。 同时,用户可能需要知道如何组织代码,例如定义数据增强的方式,使用现有的预训练模型作为编码器,或者自己构建。例如,在PyTorch中,可以使用torchvision的transforms进行数据增强,而编码器可以使用torchvision.models.resnet50,去掉最后的全连接层,添加投影头。 另外,对比损失的计算需要注意正负样本对的构造,以及温度参数的应用。在实现时,需要计算相似度矩阵,并应用logits和标签的交叉熵。 最后,用户可能需要相关的教程链接或开源项目参考,如引用[1]的GitCode项目,或者引用[4]的代码示例。同时,可能需要提醒安装必要的库,如PyTorch或TensorFlow,以及依赖项。 总结回答结构: 1. 复现步骤概述 2. PyTorch实现的关键代码(数据增强、编码器、投影头、损失函数) 3. TensorFlow实现的关键代码(类似结构) 4. 训练注意事项(引用相关资源) 5. 相关问题部分,引导用户进一步思考 需要确保代码语法正确,符合用户提到的LaTeX格式要求,例如行内公式用$...$,独立公式用$$...$$。但这个问题可能不涉及太多数学公式,但如果有的话需要正确应用。例如,对比损失的公式可能需要用LaTeX呈现。 可能的问题:用户可能混淆了SimCLRSimCLRv2,但引用中提到的项目是SimCLRv2,但用户问的是SimCLR。需要确认两者的区别,但可能回答时统一处理,或者指出差异。 综上,回答需要分框架给出实现步骤,引用提供的资源,并确保代码正确性,同时生成相关问题。</think>### SimCLR模型复现方法 SimCLR(Simple Framework for Contrastive Learning of Visual Representations)是一种基于对比学习的自监督方法,复现需重点关注数据增强、编码器设计、投影头结构和对比损失实现。以下是PyTorch和TensorFlow的实现要点: --- #### **1. PyTorch实现** **数据增强** 使用`torchvision.transforms`生成正样本对: ```python import torchvision.transforms as transforms augmentation = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(0.8, 0.8, 0.8, 0.2), transforms.GaussianBlur(kernel_size=9), transforms.ToTensor() ]) ``` **编码器与投影头** 选择ResNet作为编码器,添加MLP投影头: ```python import torch.nn as nn from torchvision.models import resnet50 class SimCLR(nn.Module): def __init__(self, hidden_dim=2048, projection_dim=128): super().__init__() self.encoder = resnet50(pretrained=False) self.encoder.fc = nn.Identity() # 移除原全连接层 self.projection = nn.Sequential( nn.Linear(2048, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, projection_dim) ) def forward(self, x): h = self.encoder(x) return self.projection(h) ``` **对比损失(NT-Xent)** 实现归一化温度标度交叉熵损失[^1]: ```python import torch import torch.nn.functional as F def nt_xent_loss(z1, z2, temperature=0.5): z = torch.cat([z1, z2], dim=0) sim_matrix = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=-1) / temperature mask = (~torch.eye(2*z1.size(0), dtype=torch.bool)).float() logits = sim_matrix * mask labels = torch.arange(z1.size(0), device=z1.device).repeat(2) return F.cross_entropy(logits, labels) ``` --- #### **2. TensorFlow实现** **编码器结构** 参考SimCLR官方设计[^4]: ```python import tensorflow as tf from tensorflow.keras import layers def get_encoder(): base_model = tf.keras.applications.ResNet50(include_top=False, weights=None) inputs = tf.keras.Input(shape=(224, 224, 3)) x = base_model(inputs) x = layers.GlobalAveragePooling2D()(x) return tf.keras.Model(inputs, x) def get_projection_head(input_dim=2048): return tf.keras.Sequential([ layers.Dense(2048, activation="relu"), layers.Dense(128) ]) ``` **训练流程** ```python # 数据增强 augment = tf.keras.Sequential([ layers.experimental.preprocessing.RandomFlip("horizontal"), layers.experimental.preprocessing.RandomContrast(0.8) ]) # 损失函数 def contrastive_loss(z1, z2, temperature=0.5): z = tf.concat([z1, z2], axis=0) sim_matrix = tf.matmul(z, z, transpose_b=True) / temperature mask = tf.eye(tf.shape(z)[0], dtype=tf.bool) logits = tf.where(mask, -tf.ones_like(sim_matrix), sim_matrix) labels = tf.range(tf.shape(z1)[0]) labels = tf.concat([labels, labels], axis=0) return tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits) ``` --- #### **训练注意事项** - **批量大小**:需使用大批次(如4096)以提升对比学习效果。 - **学习率调度**:采用余弦衰减学习率(如初始值0.3)[^2]。 - **半监督微调**:预训练后冻结编码器,添加分类层进行有监督训练。 参考开源项目: - PyTorch实现:[SimCLRv2-Pytorch](https://gitcode.com/gh_mirrors/si/SimCLRv2-Pytorch) [^1] - TensorFlow教程:[SimCLR半监督分类](https://keras.io/examples/vision/semisupervised_simclr/) ---
评论 200
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值