[pytorch] FixMatch Code Walkthrough (Detailed)

This post analyzes the key settings in the FixMatch training loop, including weight decay, learning-rate decay, and the EMA model, and shows how data augmentation and selective pseudo-labeling improve the model. It walks through the training parameters, the data-processing steps, and how the training results evolve.


The previous post covered the data-loading process; this one goes a step further and analyzes how training actually proceeds.
Previous post: [pytorch]FixMatch代码详解-数据加载

The mind map linked below lays out the overall structure of the code in detail.
Mind map

Default parameters

Dataset link
4000 labeled samples in total, i.e. 400 labeled images per class.

All parameters are left at the defaults from the author's example command:

python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5

At run time, the parameters resolve to the following values:

INFO - __main__ -   {'T': 1, 'amp': False, 'arch': 'wideresnet', 'batch_size': 64, 'dataset': 'cifar10', 'device': device(type='cuda', index=0), 'ema_decay': 0.999, 'eval_step': 1024, 'expand_labels': True, 'gpu_id': 0, 'lambda_u': 1, 'local_rank': -1, 'lr': 0.03, 'mu': 7, 'n_gpu': 1, 'nesterov': True, 'no_progress': False, 'num_labeled': 4000, 'num_workers': 4, 'opt_level': 'O1', 'out': 'results/cifar10@4000.5', 'resume': '', 'seed': 5, 'start_epoch': 0, 'threshold': 0.95, 'total_steps': 1048576, 'use_ema': True, 'warmup': 0, 'wdecay': 0.0005, 'world_size': 1}

Now let's plug these values in and see how each step runs.

Generating the data

First, the indices for the labeled and unlabeled data are generated; the corresponding code in cifar.py was analyzed in the previous post.

base_dataset = datasets.CIFAR10(
        './CIFAR10', train=True, download=True)
labels = base_dataset.targets
label_per_class = 4000 // 10
labels = np.array(labels)
labeled_idx = []
# unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
unlabeled_idx = np.array(range(len(labels)))
for i in range(10):
    idx = np.where(labels == i)[0]
    idx = np.random.choice(idx, label_per_class, False)
    labeled_idx.extend(idx)
labeled_idx = np.array(labeled_idx)
print('number labeled_idx =',len(labeled_idx))
assert len(labeled_idx) == 4000

if True or 4000 < 64:
    num_expand_x = math.ceil(
        64 * 1024 / 4000)  # ceil(16.384) = 17
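    # 64*1024 = batch_size * eval_step, i.e. the number of labeled samples consumed per "epoch";
    # the 4000 labeled indices are therefore tiled ceil(65536/4000) = 17 times (68000 indices)
    # and shuffled, so one epoch of 1024 steps fits inside a single pass of the labeled loader.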
    labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
np.random.shuffle(labeled_idx)
print('number labeled_idx = ',len(labeled_idx))
print('number unlabeled_idx =', len(unlabeled_idx))
train_labeled_idxs = labeled_idx
train_unlabeled_idxs = unlabeled_idx

The output is shown below: the unlabeled set simply uses all 50,000 training images, while the labeled indices number 68,000 after expansion.

number labeled_idx = 4000
number labeled_idx =  68000
number unlabeled_idx = 50000

Now let's look at how the images are transformed.
First, the raw image with nothing applied beyond ToTensor():

train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=transforms.ToTensor())
train_iter = iter(train_labeled_dataset)
# Visualization helper; run it repeatedly to see different samples
imgs, label = next(train_iter)
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size) # (32, 32)
image.show()
print(label)

[image: raw CIFAR-10 sample]
Next, the transform without augmentation, i.e. the one the author applies to the validation set. ToTensor() rescales pixel values from 0-255 to 0-1, and the following transforms.Normalize() then standardizes each channel with the CIFAR-10 mean and std, so the values are no longer confined to 0-1 (a quick numeric check follows after the figure below). Note that the image size does not change; it only looks larger because I zoomed in when taking the screenshot.

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)])
train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=transform_val)
train_iter = iter(train_labeled_dataset)

imgs, label = next(train_iter)
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size) # (32, 32)
image.show()
print(label)

[image: normalized sample]
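As a quick check of the value ranges mentioned above, here is a minimal sketch (imgs is the ToTensor() output from the first snippet; the exact min/max depend on the image):

from torchvision import transforms

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

print(imgs.min().item(), imgs.max().item())     # ToTensor() output, e.g. 0.0 1.0
normed = transforms.Normalize(mean=cifar10_mean, std=cifar10_std)(imgs)
print(normed.min().item(), normed.max().item()) # roughly -2.0 .. 2.1 for these stats, not (-1, 1)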
Next, the augmentation applied to the labeled images (run twice):

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
transform_labeled = transforms.Compose([
    transforms.RandomHorizontalFlip(), #Horizontally flip the given image randomly with a given probability.
    transforms.RandomCrop(size=32,
                          padding=int(32*0.125),
                          padding_mode='reflect'),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])
train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=transform_labeled)
train_iter = iter(train_labeled_dataset)

imgs, label = next(train_iter)
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size) # (32, 32)
image.show()
print(label) # 2

[images: two augmented labeled samples]

For the unlabeled data there are two augmentations: a weak one and a strong one. The strong augmentation is described in the paper as follows.
[figures: the strong-augmentation (RandAugment + Cutout) description from the paper]

class TransformFixMatch(object):
    def __init__(self, mean, std):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect')])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)
# The strong-augmentation ops, defined in randaugment.py
def fixmatch_augment_pool():
    # FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs
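    # Each entry is (op, max_v, bias): inside randaugment.py the op derives its actual strength
    # from the sampled level v together with max_v and bias; ops whose max_v is None
    # (AutoContrast, Equalize, Identity) ignore the level entirely.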
    
class RandAugmentMC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            v = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32*0.5))
        return img
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
train_iter = iter(train_labeled_dataset)

(inputs_u_w, inputs_u_s), _ = next(train_iter)
print(inputs_u_s.shape) # torch.Size([3, 32, 32])
image = transforms.ToPILImage()(inputs_u_s).convert('RGB')
image.show()

The weakly augmented results (run twice):
[images: two weakly augmented samples]
The strongly augmented results (run four times):

[images: four strongly augmented samples]
So the labeled / unlabeled / test datasets and their DataLoaders are built as follows:

labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=transform_labeled)
# len = 68000
unlabeled_dataset = CIFAR10SSL(
    './data', train_unlabeled_idxs, train=True,
    transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
# len = 50000
test_dataset = datasets.CIFAR10(
    './data', train=False, transform=transform_val, download=False)
# len = 10000
train_sampler = RandomSampler
labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=64,
        num_workers=4,
        drop_last=True)
# len = 68000/64 = 1062.5 -> 1062 (drop_last=True)
unlabeled_trainloader = DataLoader(
    unlabeled_dataset,
    sampler=train_sampler(unlabeled_dataset),
    batch_size=64*7, # mu = 7: the unlabeled batch-size multiplier μ from the paper
    num_workers=4,
    drop_last=True)
# len = floor(50000/(64*7)) = 111 (drop_last=True)
test_loader = DataLoader(
    test_dataset,
    sampler=SequentialSampler(test_dataset),
    batch_size=64,
    num_workers=7)
# len = 10000/64 = 156.25 -> 157 (drop_last=False)

Building the model

def create_model():
    import models.wideresnet as models
    model = models.build_wideresnet(depth=28,
                                    widen_factor=2,
                                    dropout=0,
                                    num_classes=10)
    return model
    
model = create_model()
#print(model)
#for p in model.parameters():
#    print(p.numel())
total_num = sum(p.numel() for p in model.parameters())
print(total_num) # 1467610 parameters in total (≈1.47M)

Training parameter settings

The parameter setup involves quite a few common training tricks; I'll briefly go over how each is configured. Here are some of the authors' conclusions.
[figures: excerpts from the paper with the authors' conclusions on the training setup]

Weight decay

Weight decay is meant to prevent overfitting. In the loss function, weight decay is the coefficient placed in front of the regularization term; that term generally reflects the model's complexity, so weight decay controls how much model complexity contributes to the loss: with a large weight decay, a complex model also incurs a large loss.
The authors also note that the SGD optimizer should be used.

# weight decay default=5e-4
no_decay = ['bias', 'bn']
grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': 5e-4},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = optim.SGD(grouped_parameters, lr=0.03,
                      momentum=0.9, nesterov=True)

Weight decay is applied to every parameter except the biases and the BatchNorm layers.
[figures: related excerpts from the paper]
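As a quick sanity check of this grouping, here is a minimal sketch (the exact parameter names depend on the WideResNet implementation) that lists which parameters end up in each group:

no_decay = ['bias', 'bn']
decay_names, no_decay_names = [], []
for name, _ in model.named_parameters():
    (no_decay_names if any(nd in name for nd in no_decay) else decay_names).append(name)
print(len(decay_names), len(no_decay_names))
print(no_decay_names[:3]) # e.g. BatchNorm weights/biases and layer biases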

Learning rate decay

As the authors mention in the paper, the learning rate follows a cosine decay, combined with a warmup phase: the learning rate starts small and grows until it reaches the configured value at num_warmup_steps, and after that it follows the cosine decay given by the formula in the paper.

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

scheduler = get_cosine_schedule_with_warmup(optimizer, 0, 2**20)

The learning rate is one of the most important hyperparameters in neural-network training, and warmup is one of many ways to schedule it.
(1) What is warmup?
Warmup is a learning-rate warm-up technique mentioned in the ResNet paper: training starts with a small learning rate for some epochs or steps (say 4 epochs or 10,000 steps) before switching to the preset learning rate.

(2) Why use warmup?
At the start of training the weights are randomly initialized, and choosing a large learning rate right away can make the model unstable (oscillate). With warmup, the learning rate stays small for the first epochs or steps so that the model can gradually stabilize; only then is the preset learning rate used, which speeds up convergence and improves the final result.

Example: when training a 110-layer ResNet on CIFAR-10, the ResNet paper first uses a learning rate of 0.01 until the training error drops below 80% (roughly 400 steps) and then switches to 0.1.

Custom schedules are implemented with LambdaLR, which adjusts the learning rate through a user-defined function:
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
[figure: LambdaLR documentation excerpt]
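To see the shape of the schedule, here is a minimal sketch that evaluates it at a few steps, reusing get_cosine_schedule_with_warmup from above with the default settings (lr=0.03, warmup=0, total_steps=2**20); with warmup=0 only the cosine part is active, and the printed values are approximate:

import torch

dummy = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(dummy.parameters(), lr=0.03)
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=0, num_training_steps=2**20)

for step in [0, 2**18, 2**19, 2**20 - 1]:
    factor = sched.lr_lambdas[0](step)  # LambdaLR multiplies the base lr by this factor
    print(step, 0.03 * factor)
# 0        0.03
# 262144   ~0.0282
# 524288   ~0.0232
# 1048575  ~0.0059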

Exponential moving average (EMA) model

This algorithm is one of the most important algorithms currently in use. It is applied extensively to financial time series, signal processing, neural networks, and basically any sequential data.
We mostly use it to reduce the noise in noisy time-series data; the term for this is "smoothing" the data.
It does so by weighting the observations and taking their average, which is called a moving average.
In deep learning, the EMA (exponential moving average) method is often used to average the parameters of the model in order to improve the test metrics and make the model more robust.
I don't fully understand this trick myself; see this post for an introduction: 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现
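The core idea is simple, though: keep a shadow copy of the model whose parameters are a decayed running average of the training parameters, update it after every optimizer step, and use the shadow copy for evaluation. Below is a minimal sketch of that idea (this is not the repo's ModelEMA class, just an illustration, with decay=0.999 as in the defaults):

import copy
import torch

class SimpleEMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.ema = copy.deepcopy(model)  # shadow model used at evaluation time
        self.ema.eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # ema_param = decay * ema_param + (1 - decay) * param
        for ema_p, p in zip(self.ema.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1 - self.decay)
        # buffers (e.g. BatchNorm running stats) are simply copied over
        for ema_b, b in zip(self.ema.buffers(), model.buffers()):
            ema_b.copy_(b)

With this sketch, ema_model = SimpleEMA(model) would be created once, ema_model.update(model) called after each optimizer step, and testing done with ema_model.ema.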

Training process

[figure: overview of the training process]

Everything is explained in the comments below; each step should be clear.

# Setup
epochs = math.ceil(2**20 / 1024) # 1024 epochs in total (total_steps / eval_step)
start_epoch = 0
test_accs = []
end = time.time() # timestamp used to measure per-batch time
def interleave(x, size):
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])
def de_interleave(x, size):
    s = list(x.shape)
    return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])
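# interleave/de_interleave reorder the concatenated batch (64 labeled + 448 weak + 448 strong
# = 960 samples) before the forward pass and restore the original order afterwards, so the
# labeled logits can simply be sliced off the front; the trick is usually explained as keeping
# BatchNorm statistics consistent across the mixed labeled/unlabeled batch.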
labeled_iter = iter(labeled_trainloader)
unlabeled_iter = iter(unlabeled_trainloader)
model.train()
for epoch in range(start_epoch, epochs):
    #batch_time = AverageMeter()  # AverageMeter just computes and stores running statistics, e.g. of the losses
    #data_time = AverageMeter()
    #losses = AverageMeter()
    #losses_x = AverageMeter()
    #losses_u = AverageMeter()
    #mask_probs = AverageMeter()
    p_bar = tqdm(range(1024))
    for batch_idx in range(1024):
        
        # Read a fixed number of batches with iter()/next() instead of looping over the
        # DataLoaders directly, since the two loaders have different lengths.
        try:
            inputs_x, targets_x = next(labeled_iter)
            #print(inputs_x.shape) # torch.Size([64, 3, 32, 32])
            #print(targets_x.shape) # torch.Size([64])
            #print(targets_x)
        except StopIteration:  # when the iterator is exhausted, start a new pass
            labeled_iter = iter(labeled_trainloader)
            inputs_x, targets_x = next(labeled_iter)
        try:
            (inputs_u_w, inputs_u_s), _ = next(unlabeled_iter)
            #print(inputs_u_w.shape) #torch.Size([448, 3, 32, 32])
            #print(inputs_u_s.shape) #torch.Size([448, 3, 32, 32])
        except StopIteration:
            unlabeled_iter = iter(unlabeled_trainloader)
            (inputs_u_w, inputs_u_s), _ = next(unlabeled_iter)
        # print(time.time() - end) # data_time: about 200 s in my run to fetch one set of batches
        
        
        batch_size = inputs_x.shape[0] #64
        new_data = interleave(
                torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2*7+1) #'mu': 7
        # print(new_data.shape) # torch.Size([960, 3, 32, 32]): 64+448+448 = 64*(2*7+1), all samples merged into one batch
        inputs = new_data.to(device)
        targets_x = targets_x.to(device)
        
        
        logits = model(inputs)
        #print(logits.shape) #torch.Size([960, 10])
        logits = de_interleave(logits, 2*7+1)
        #print(logits.shape) #torch.Size([960, 10])
        logits_x = logits[:batch_size]
        #print(logits_x.shape) #torch.Size([64, 10])
        logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
        #print(logits_u_w.shape) #torch.Size([448, 10]) 
        
        # Compute the pseudo labels and the mask from the weakly augmented samples;
        # the mask marks which samples have a maximum predicted probability above the threshold
        # and can therefore be used, and which cannot.
        
        Lx = F.cross_entropy(logits_x, targets_x, reduction='mean') # supervised loss on the labeled data
        #print(Lx) # tensor(2.6575, device='cuda:0', grad_fn=<NllLossBackward0>)
        pseudo_label = torch.softmax(logits_u_w.detach()/1, dim=-1) # turn the logits into probabilities
        # pseudo-label temperature T = 1; the ordinary softmax is the special case T = 1.
        # The larger T is, the smoother the output probability distribution and the larger its entropy;
        # the information carried by the non-target classes is amplified and training pays more attention to them.
        max_probs, targets_u = torch.max(pseudo_label, dim=-1)
        #print(max_probs.shape) # torch.Size([448]): the 448 maximum probabilities
        #print(targets_u.shape) # torch.Size([448]): the 448 pseudo labels
        #print(targets_u) #tensor([3, 5, 1 ....], device='cuda:0')
        mask = max_probs.ge(0.95).float() # 'threshold': 0.95
        # torch.ge(a, b) compares a and b element-wise (a >= b)
        # print(mask.shape) #torch.Size([448]): 448 zeros/ones
        # print(F.cross_entropy(logits_u_s, targets_u, reduction='none')) # reduction='none': no averaging, returns 448 values
        Lu = (F.cross_entropy(logits_u_s, targets_u,
                                  reduction='none') * mask).mean() # unsupervised loss; the mask filters the samples
        #print(Lu) #tensor(0., device='cuda:0', grad_fn=<MeanBackward0>)

        loss = Lx + 1 * Lu # 'lambda_u': 1, the full loss
        
        
        print(time.time() - end) # about 3439 s in my run: batch_time, time to finish one set of batches
        end = time.time() # reset for the next iteration
        print(mask.mean().item()) # mask_probs: mean of the mask, i.e. the fraction of samples above the threshold
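The excerpt ends with these diagnostics; in the actual train loop the parameters are then updated. Here is a minimal sketch of the remaining steps inside the inner batch loop (backward pass, optimizer and scheduler step, EMA update; use_ema stands for the use_ema flag and ema_model for the EMA wrapper discussed above, and the bookkeeping in train.py differs slightly):

        loss.backward()
        optimizer.step()
        scheduler.step()              # advance the cosine schedule once per optimization step
        if use_ema:                   # 'use_ema': True in the defaults
            ema_model.update(model)   # refresh the EMA weights after the optimizer step
        model.zero_grad()
        p_bar.update()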

Results

First, let's look at the program's output.
The beginning:

(torch) liyihao@liyihao-Precision-5820-Tower:~/LI/FixMatch-pytorch-master$ python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/test1
02/16/2022 17:48:12 - WARNING - __main__ -   Process rank: -1, device: cuda:0, n_gpu: 1, distributed training: False, 16-bits training: False
02/16/2022 17:48:12 - INFO - __main__ -   {'T': 1, 'amp': False, 'arch': 'wideresnet', 'batch_size': 64, 'dataset': 'cifar10', 'device': device(type='cuda', index=0), 'ema_decay': 0.999, 'eval_step': 1024, 'expand_labels': True, 'gpu_id': 0, 'lambda_u': 1, 'local_rank': -1, 'lr': 0.03, 'mu': 7, 'n_gpu': 1, 'nesterov': True, 'no_progress': False, 'num_labeled': 4000, 'num_workers': 4, 'opt_level': 'O1', 'out': 'results/test1', 'resume': '', 'seed': 5, 'start_epoch': 0, 'threshold': 0.95, 'total_steps': 1048576, 'use_ema': True, 'warmup': 0, 'wdecay': 0.0005, 'world_size': 1}
Files already downloaded and verified
02/16/2022 17:48:14 - INFO - models.wideresnet -   Model: WideResNet 28x2
02/16/2022 17:48:14 - INFO - __main__ -   Total params: 1.47M
02/16/2022 17:48:18 - INFO - __main__ -   ***** Running training *****
02/16/2022 17:48:18 - INFO - __main__ -     Task = cifar10@4000
02/16/2022 17:48:18 - INFO - __main__ -     Num Epochs = 1024
02/16/2022 17:48:18 - INFO - __main__ -     Batch size per GPU = 64
02/16/2022 17:48:18 - INFO - __main__ -     Total train batch size = 64
02/16/2022 17:48:18 - INFO - __main__ -     Total optimization steps = 1048576
Train Epoch: 1/1024. Iter: 1024/1024. LR: 0.0300. Data: 0.045s. Batch: 0.207s. Loss: 1.2336. Loss_x: 1.1920. Loss_u: 0.0416. Mask: 0.07. : 100%|| 102
Test Iter:  157/ 157. Data: 0.005s. Batch: 0.012s. Loss: 1.8805. top1: 31.71. top5: 81.31. : 100%|██████████████████| 157/157 [00:02<00:00, 77.27it/s]
02/16/2022 17:51:52 - INFO - __main__ -   top-1 acc: 31.71
02/16/2022 17:51:52 - INFO - __main__ -   top-5 acc: 81.31
02/16/2022 17:51:52 - INFO - __main__ -   Best top-1 acc: 31.71
02/16/2022 17:51:52 - INFO - __main__ -   Mean top-1 acc: 31.71

Train Epoch: 2/1024. Iter: 1024/1024. LR: 0.0300. Data: 0.046s. Batch: 0.206s. Loss: 0.7871. Loss_x: 0.6212. Loss_u: 0.1659. Mask: 0.31. : 100%|| 102
Test Iter:  157/ 157. Data: 0.005s. Batch: 0.012s. Loss: 0.9442. top1: 66.99. top5: 97.58. : 100%|██████████████████| 157/157 [00:01<00:00, 80.80it/s]
02/16/2022 17:55:22 - INFO - __main__ -   top-1 acc: 66.99
02/16/2022 17:55:22 - INFO - __main__ -   top-5 acc: 97.58
02/16/2022 17:55:22 - INFO - __main__ -   Best top-1 acc: 66.99
02/16/2022 17:55:22 - INFO - __main__ -   Mean top-1 acc: 49.35

Train Epoch: 3/1024. Iter: 1024/1024. LR: 0.0300. Data: 0.045s. Batch: 0.206s. Loss: 0.5908. Loss_x: 0.3215. Loss_u: 0.2692. Mask: 0.50. : 100%|| 102
Test Iter:  157/ 157. Data: 0.005s. Batch: 0.012s. Loss: 0.6990. top1: 75.80. top5: 98.54. : 100%|██████████████████| 157/157 [00:02<00:00, 77.19it/s]
02/16/2022 17:58:53 - INFO - __main__ -   top-1 acc: 75.80
02/16/2022 17:58:53 - INFO - __main__ -   top-5 acc: 98.54
02/16/2022 17:58:54 - INFO - __main__ -   Best top-1 acc: 75.80
02/16/2022 17:58:54 - INFO - __main__ -   Mean top-1 acc: 58.17

After running for 100+ epochs:

Train Epoch: 150/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.017s. Batch: 0.157s. Loss: 0.2174. Loss_x: 0.0090. Loss_u: 0.2084. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.004s. Batch: 0.009s. Loss: 0.2418. top1: 94.17. top5: 99.87. : 100%|██████████████████| 157/157 [00:01<00:00, 99.77it/s]
02/17/2022 00:54:18 - INFO - __main__ -   top-1 acc: 94.17
02/17/2022 00:54:18 - INFO - __main__ -   top-5 acc: 99.87
02/17/2022 00:54:18 - INFO - __main__ -   Best top-1 acc: 94.28
02/17/2022 00:54:18 - INFO - __main__ -   Mean top-1 acc: 94.03

Train Epoch: 151/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.018s. Batch: 0.158s. Loss: 0.2118. Loss_x: 0.0066. Loss_u: 0.2052. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.004s. Batch: 0.010s. Loss: 0.2393. top1: 94.37. top5: 99.91. : 100%|██████████████████| 157/157 [00:01<00:00, 89.00it/s]
02/17/2022 00:57:00 - INFO - __main__ -   top-1 acc: 94.37
02/17/2022 00:57:00 - INFO - __main__ -   top-5 acc: 99.91
02/17/2022 00:57:00 - INFO - __main__ -   Best top-1 acc: 94.37
02/17/2022 00:57:00 - INFO - __main__ -   Mean top-1 acc: 94.05

Train Epoch: 152/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.017s. Batch: 0.158s. Loss: 0.2209. Loss_x: 0.0097. Loss_u: 0.2113. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.004s. Batch: 0.009s. Loss: 0.2414. top1: 94.19. top5: 99.86. : 100%|█████████████████| 157/157 [00:01<00:00, 100.27it/s]
02/17/2022 00:59:41 - INFO - __main__ -   top-1 acc: 94.19
02/17/2022 00:59:41 - INFO - __main__ -   top-5 acc: 99.86
02/17/2022 00:59:41 - INFO - __main__ -   Best top-1 acc: 94.37
02/17/2022 00:59:41 - INFO - __main__ -   Mean top-1 acc: 94.06

Train Epoch: 153/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.017s. Batch: 0.159s. Loss: 0.2210. Loss_x: 0.0110. Loss_u: 0.2100. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.003s. Batch: 0.009s. Loss: 0.2439. top1: 94.07. top5: 99.87. : 100%|█████████████████| 157/157 [00:01<00:00, 101.27it/s]
02/17/2022 01:02:24 - INFO - __main__ -   top-1 acc: 94.07
02/17/2022 01:02:24 - INFO - __main__ -   top-5 acc: 99.87
02/17/2022 01:02:24 - INFO - __main__ -   Best top-1 acc: 94.37
02/17/2022 01:02:24 - INFO - __main__ -   Mean top-1 acc: 94.06

Finally, TensorBoard shows how the various metrics evolve.
Metrics on the validation set:
[figure: validation-set curves]
Metrics on the training set:
[figures: training-set curves]
