半监督:FixMatch

总结下fixmatch总体做了什么:

  • 计算目标域训练集中熵值小于熵值均值的样本所占的比例PS
  • 筛选目标域训练集中每个类中熵值最小的PS比例的数据集作为新源域,目标域训练集经过网络得到的predict作为label
  • 筛选目标域训练集中每个类中剩下的数据集作为新目标域,目标域训练集经过网络得到的predict作为label
  • 新目标域上两batch input:inputs_t和inputs_t2分别经过网络得到output1和output2,将output1和2混合:p=(output1+output2)/2,之后pt=p**1/T,pt的每一行除以该行的和得到targets_u,作为无标签目标域的target。即: p t = ( o u t p u t 1 + o u t p u t 2 2 ) 1 / T pt = (\frac{output1+output2}{2})^{1/T} pt=(2output1+output2)1/T t a r g e t s _ u = p t p t . s u m ( d i m = 1 ) targets\_u = \frac{pt}{pt.sum(dim=1)} targets_u=pt.sum(dim=1)pt
  • [inputs_s, inputs_t, inputs_t2]作为input,[targets_s, targets_u, targets_u]作为target
  • 将打乱和未打乱的input混合得到mixed_input,同理得到mixed_target(未打乱的占主要地位,所以还可以认为是[源域,目标域,目标域]的组合)
  • 交错mixed_input中有标签和无标签的样本(源域和目标域)(再次交错可还原)
  • 将mixed_input输入到网络中得到logits,之后将logits重新交错(还原)
  • 根据mixed_input中的有标签和无标签样本以及logits来计算semiloss
  • 反向传播训练网络

GPT总结:
FixMatch 是一种用于半监督学习的方法,它结合了伪标签生成和一致性正则化来有效利用无标签数据。以下是 FixMatch 的关键步骤:

  1. 计算熵值:计算目标域训练集中熵值小于熵值均值的样本所占的比例 (PS)。
  2. 筛选数据:根据熵值筛选目标域训练集中每个类中熵值最小的 (PS) 比例的数据集作为新源域,并将目标域训练集经过网络得到的预测作为标签。
  3. 生成伪标签:筛选目标域训练集中每个类中剩下的数据集作为新目标域,并将目标域训练集经过网络得到的预测作为标签。
  4. 混合输出:在新目标域上,两批输入 inputs_tinputs_t2 分别经过网络得到 output1output2,将 output1output2 混合得到 (p = \frac{output1 + output2}{2}),然后通过温度缩放 (pt = p^{1/T}),并归一化得到 targets_u
  5. 组合输入和目标:将 [inputs_s, inputs_t, inputs_t2] 作为输入,将 [targets_s, targets_u, targets_u] 作为目标。
  6. 交错输入:将打乱和未打乱的输入混合得到 mixed_input,同理得到 mixed_target
  7. 交错样本:交错 mixed_input 中有标签和无标签的样本(源域和目标域),并在之后还原。
  8. 网络计算:将 mixed_input 输入到网络中得到 logits,之后将 logits 重新交错(还原)。
  9. 计算损失:根据 mixed_input 中的有标签和无标签样本以及 logits 来计算 semiloss
  10. 反向传播:反向传播训练网络。

既然交错input后还需将output也就是logit复原,那么交错的意义在哪里呢?

交错输入和输出的主要目的是为了在训练过程中保持批归一化(BatchNorm)的统计数据一致性。
交错源域和目标域的输入后,BatchNorm层能够正确计算统计数据的原因在于它能够更好地反映整个数据集的统计特性。具体来说:

  • 均匀分布样本:
    如果批次中的样本仅来自源域或目标域,BatchNorm层计算的均值和方差可能会偏向某一域的数据特性。
    通过交错源域和目标域的样本,可以确保每个批次中的样本更均匀地分布在两个域之间,从而使BatchNorm层计算的均值和方差更准确地反映整个数据集的统计特性。
  • 减少偏差:
    在半监督学习中,源域和目标域的数据分布可能存在差异。如果BatchNorm层仅基于单一域的数据进行计算,可能会导致模型在另一域上的表现不佳。
    交错输入可以减少这种偏差,使得BatchNorm层在计算统计数据时能够综合考虑两种数据分布,从而提高模型在不同域上的泛化能力。
  • 稳定训练过程:
    BatchNorm层的统计数据对模型的训练过程有重要影响。交错输入可以使得每个批次中的统计数据更加稳定,避免因单一域数据导致的统计波动,从而使训练过程更加稳定和高效。

训练函数代码(可对照上述步骤理解):


# 主函数中传入txt_src:目标域训练集中选择的熵值低于PS比例的样本, txt_tgt:目标域训练集中选择的熵值高于PS比例的样本
def train(args, txt_src, txt_tgt):
    ## set pre-process
    #     dset_loaders["train"]:目标域所有训练集
    #     dset_loaders["test"]:目标域所有测试集
    #     dset_loaders["source"]:目标域训练集中选择的熵值低于PS比例的样本
    #     dset_loaders["target"]:目标域训练集中选择的熵值高于PS比例的样本
    #     因为splitdata时没有涉及源域数据集,所以也可以理解为
    #     dset_loaders["train"]:训练集
    #     dset_loaders["test"]:测试集
    #     dset_loaders["source"]:训练集中选择的熵值低于PS比例的样本,作为源域,也即置信度比较高的部分,可以认为其伪标签就是真实的标签
    #     dset_loaders["target"]:训练集中选择的熵值高于PS比例的样本,作为目标域,也即置信度比较低的部分,需要mixmatch
    dset_loaders = data_load(args, txt_src, txt_tgt)
    # pdb.set_trace()
    max_len = max(len(dset_loaders["source"]), len(dset_loaders["target"]))
    max_iter = args.max_epoch*max_len
    interval_iter = max_iter // 10
    # netG:主体网络,特征提取网络
    if args.dset == 'u2m':
        netG = network.LeNetBase().to(device)
    elif args.dset == 'm2u':
        netG = network.LeNetBase().to(device)
    elif args.dset == 's2m':
        netG = network.DTNBase().to(device)
    # netB:bn,relu,dropout,linear
    netB = network.feat_bottleneck(type=args.classifier, feature_dim=netG.in_features, bottleneck_dim=args.bottleneck).to(device)
    # netC:linear(out_channels = 10)
    netC = network.feat_classifier(type=args.layer, class_num = args.class_num, bottleneck_dim=args.bottleneck).to(device)

    if args.model == 'source':
        modelpath = args.output_dir + "/source_F.pt" 
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/source_B.pt"   
        netB.load_state_dict(torch.load(modelpath))
    else:
        modelpath = args.output_dir + "/target_F_" + args.savename + ".pt" 
        netG.load_state_dict(torch.load(modelpath))
        modelpath = args.output_dir + "/target_B_" + args.savename + ".pt"   
        netB.load_state_dict(torch.load(modelpath))
    # netF:特征提取网络后的部分:bn,relu,dropout,linear
    netF = nn.Sequential(netB, netC)
    optimizer_g = optim.SGD(netG.parameters(), lr = args.lr * 0.1)
    optimizer_f = optim.SGD(netF.parameters(), lr = args.lr)
    # base_network:网络整体,包括特征提取和分类器
    base_network = nn.Sequential(netG, netF)
    source_loader_iter = iter(dset_loaders["source"])
    target_loader_iter = iter(dset_loaders["target"])

    list_acc = []
    best_ent = 100

    for iter_num in range(1, max_iter + 1):
        base_network.train()
        lr_scheduler(optimizer_g, init_lr=args.lr * 0.1, iter_num=iter_num, max_iter=max_iter)
        lr_scheduler(optimizer_f, init_lr=args.lr, iter_num=iter_num, max_iter=max_iter)
        # inputs_source:源域(即原目标域中置信度较高的部分)数据集的图像 labels_source:源域数据集的标签
        try:
            inputs_source, labels_source = next(source_loader_iter)
        except:
            source_loader_iter = iter(dset_loaders["source"])
            inputs_source, labels_source = next(source_loader_iter)
        # inputs_target:目标域(即原目标域中置信度较低的部分)数据集的图像 target_idx:该batch中目标域数据集的索引
        try:
            # inputs_target:一个列表,包含两个张量,每个张量的形状为(batch_size, 3, 32, 32)
            inputs_target, _, target_idx = next(target_loader_iter)
        except:
            target_loader_iter = iter(dset_loaders["target"])
            inputs_target, _, target_idx = next(target_loader_iter)
        # 这段代码将创建一个形状为 (args.batch_size, args.class_num) 的张量,其中每一行对应一个样本,每一行中只有一个位置为 1,其余位置为 0。这个位置由 labels_source 中的标签决定。
        # 这用于将源域的标签转换为 one-hot 编码形式,即targets_s是源域数据集的标签的one-hot编码形式
        targets_s = torch.zeros(args.batch_size, args.class_num).scatter_(1, labels_source.view(-1,1), 1)
        inputs_s = inputs_source.to(device)
        targets_s = targets_s.to(device)
        inputs_t = inputs_target[0].to(device)
        inputs_t2 = inputs_target[1].to(device)

        with torch.no_grad():
            # compute guessed labels of unlabel samples
            outputs_u = base_network(inputs_t)
            outputs_u2 = base_network(inputs_t2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2 # p:(36,10)
            pt = p**(1/args.T) # pt:p的1/T次方
            targets_u = pt / pt.sum(dim=1, keepdim=True) #targets_u:pt的每一行除以该行的和
            targets_u = targets_u.detach()

        # 将该batch的源域和目标域的数据集合并,inputs_s:源域数据集的图像 inputs_t:目标域数据集的图像 inputs_t2:目标域数据集的图像
        all_inputs = torch.cat([inputs_s, inputs_t, inputs_t2], dim=0)
        # targets_s:由源域数据集的标签转换为 one-hot 编码形式 targets_u:两input_target经过网络输出混合处理之后得到的标签
        all_targets = torch.cat([targets_s, targets_u, targets_u], dim=0)
        if args.alpha > 0:
            # l是一个从 Beta 分布中抽取的随机数,范围在 0.5 到 1 之间
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1-l)
        else:
            l = 1
        # idx即打乱后的inputs索引
        idx = torch.randperm(all_inputs.size(0))
        # inputs_a:未打乱的inputs inputs_b:打乱后的inputs
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        # mixed_input:混合后的inputs(打乱和未打乱的inputs混合,因为l大于0.5所以未打乱的inputs占主要地位) mixed_target:混合后的targets
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # interleave labeled and unlabed samples between batches to get correct batchnorm calculation 
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        # 交错有标签和无标签的样本
        mixed_input = utils.interleave(mixed_input, args.batch_size)  
        # s = [sa, sb, sc] = [(12,3,32,32), (12,3,32,32), (12,3,32,32)]
        # t1 = [t1a, t1b, t1c]
        # t2 = [t2a, t2b, t2c]
        # => s' = [sa, t1b, t2c]   t1' = [t1a, sb, t1c]   t2' = [t2a, t2b, sc]
        # 可以发现将s`和t1`、t2`再次interleave后,就是原来的s、t1、t2

        #logits为mixed_input经过网络后的输出
        logits = base_network(mixed_input[0])
        logits = [logits]
        for input in mixed_input[1:]:
            temp = base_network(input)
            logits.append(temp)

        # put interleaved samples back
        # [i[:,0] for i in aa]
        # 将s`和t1`、t2`再次interleave后,就是原来的s、t1、t2
        # 由于只有inputs进行了交错,target没有交错所以需要将output重新复位顺序
        # 将 logits 重新交错,以匹配输入的原始顺序。
        logits = utils.interleave(logits, args.batch_size)
        # logits_x:源域数据的输出(有标签) logits_u:目标域数据的输出(无标签)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)
        # 上述代码中的逻辑可以总结如下:将mixed_input中的源域数据与目标域数据交错,输入到网络中得到logits,再将logits重新交错得到原输入顺序对应的输出结果。
        # 这种方法确保了在训练过程中,BatchNorm层能够正确计算统计数据,同时也增强了模型的鲁棒性和泛化能力。
        # 交错源域和目标域的输入后,BatchNorm层能够正确计算统计数据的原因在于 它能够更好地反映整个数据集的统计特性。

        train_criterion = utils.SemiLoss()
        # mixed_target: target_s,target_t,target_t
        # 计算semiloss输入:源域数据输出,源域数据target,目标域数据输出,目标域数据target,当前迭代次数,最大迭代次数,参数lambda
        Lx, Lu, w = train_criterion(logits_x, mixed_target[:args.batch_size], logits_u, mixed_target[args.batch_size:], 
            iter_num, max_iter, args.lambda_u)
        loss = Lx + w * Lu

        optimizer_g.zero_grad()
        optimizer_f.zero_grad()
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        if iter_num % interval_iter == 0 or iter_num == max_iter:
            base_network.eval()
            # accuracy, predict, all_output, all_label
            acc, py, score, y = cal_acc(dset_loaders["train"], base_network, flag=False)
            mean_ent = torch.mean(Entropy(score))
            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(args.dset + '_train', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str+'\n') 


            acc, py, score, y = cal_acc(dset_loaders["test"], base_network, flag=False)
            mean_ent = torch.mean(Entropy(score))
            list_acc.append(acc)

            if best_ent > mean_ent:
                val_acc = acc
                best_ent = mean_ent
                best_y = y
                best_py = py
                best_score = score

            log_str = 'Task: {}, Iter:{}/{}; Accuracy = {:.2f}%; Mean Ent = {:.4f}'.format(args.dset + '_test', iter_num, max_iter, acc, mean_ent)
            args.out_file.write(log_str + '\n')
            args.out_file.flush()
            print(log_str+'\n')       

    idx = np.argmax(np.array(list_acc))
    max_acc = list_acc[idx]
    final_acc = list_acc[-1]

    log_str = '\n==========================================\n'
    log_str += '\nVal Acc = {:.2f}\nMax Acc = {:.2f}\nFin Acc = {:.2f}\n'.format(val_acc, max_acc, final_acc)
    args.out_file.write(log_str + '\n')
    args.out_file.flush()  

    # torch.save(base_network.state_dict(), osp.join(args.output_dir, args.log + ".pt"))
    # sio.savemat(osp.join(args.output_dir, args.log + ".mat"), {'y':best_y.cpu().numpy(), 
    #     'py':best_py.cpu().numpy(), 'score':best_score.cpu().numpy()})
    
    return base_network, py
  • 20
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值