FixMatch文章解读+算法流程+核心代码详解

FixMatch

本博客仅做算法流程疏导,具体细节请参见原文

原文

查看原文点这里

Github代码

Github代码点这里

解读

FixMatch算法抓住了半监督算法的两个重要观点,第一个是一致性正则化,第二个是伪标记。一致性正则化在MixMatch中已经介绍过了,在此不再赘述。伪标记是一种常用的半监督算法。

伪标记

伪标记(pseudo label)其实算最早的一类半监督算法,代表算法self-training。简单地说就是通过训练的模型对无标记样本打标签,这个标签有对有错,通过一些方法筛选标签后,选择一部分无标记样本和模型打的标签一起送入模型继续训练。伪标记的方法最大问题在于,如何保证伪标记的正确性。因为当模型打的标签提供了较多的错误信息时,会使模型的训练结果更劣。一般常见的筛选方式是将模型输出的预测结果( S o f t m a x Softmax Softmax之后)进行阈值判断,其 a r g m a x argmax argmax的概率大于阈值,才认为是有效标记,否则将此无标记样本丢弃。

整体算法

FixMatch算法并不复杂,结合一致性正则化和伪标记两种算法。由其论文中的流程图就可以很好的理解。

image-20210802101107656

对于有标记样本,进行正常的监督学习,损失函数为 C r o s s E n t r o p y L o s s CrossEntropyLoss CrossEntropyLoss​,得到 L s L_s Ls​。其公式表达如下:

L s = 1 B ∑ b = 1 B H ( p b , p m ( y ∣ α ( x b ) ) ) L_s=\frac{1}{B}\sum^B_{b=1}H(p_b,p_m(y|\alpha(x_b))) Ls=B1b=1BH(pb,pm(yα(xb)))

对于无标记样本,参照上图,共四步。

第一步,先对无标记样本进行扩增(Augment),扩增分为强扩增和弱扩增,弱扩增使用标准的旋转和移位;强扩增使用RandAugment和CTAugment两种算法。

第二步,对扩增后的样本进行预测。对于弱扩增的样本,输出的预测结果( S o f t m a x Softmax Softmax之后的)最高预测概率(即 a r g m a x argmax argmax的结果)大于阈值(图中的虚线),则认为是有效的样本,将其预测结果作为标签(这就是pseudo label)。

第三步:对强扩增的样本,输出的预测结果和对应弱标记样本得到的标签做 C r o s s E n t r o p y L o s s CrossEntropyLoss CrossEntropyLoss​,得到损失函数 L u L_u Lu​。其公式表达为:

L u = 1 μ B ∑ b = 1 μ B 1 ( m a x ( q b ) ≥ τ ) H ( q b ^ , p m ( y ∣ A ( u b ) ) ) L_u=\frac{1}{\mu B}\sum^{\mu B}_{b=1}\mathcal{1}(max(q_b)\geq \tau )H(\hat{q_b},p_m(y|\mathcal{A}(u_b))) Lu=μB1b=1μB1(max(qb)τ)H(qb^,pm(yA(ub)))

简而言之就是选择 m a x ( q b ) ≥ τ max(q_b)\geq \tau max(qb)τ H ( q b ^ , p m ( y ∣ A ( u b ) ) H(\hat{q_b},p_m(y|\mathcal{A}(u_b)) H(qb^,pm(yA(ub))作为 L u L_u Lu的组成成分,参与反向梯度传播更新。​

第四步:最终损失函数为 L o s s = L s + α L u Loss = L_s+\alpha L_u Loss=Ls+αLu α \alpha α是超参数。

L o s s Loss Loss反向梯度传播完成整个算法模型更新。

核心代码解读

这里读取一个batch的操作,和前一篇MixMatch的代码实现相同,为了读取指定次数的batch,而不通过Dataloader。

for batch_idx in range(args.eval_step):
    try:
        inputs_x, targets_x = labeled_iter.next()
    except:
        if args.world_size > 1:
            labeled_epoch += 1
            labeled_trainloader.sampler.set_epoch(labeled_epoch)
        labeled_iter = iter(labeled_trainloader)
        inputs_x, targets_x = labeled_iter.next()

    try:
        (inputs_u_w, inputs_u_s), _ = unlabeled_iter.next()
    except:
        if args.world_size > 1:
            unlabeled_epoch += 1
            unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)
        unlabeled_iter = iter(unlabeled_trainloader)
        (inputs_u_w, inputs_u_s), _ = unlabeled_iter.next()

得到strong_augment样本和weak_augment样本,分别为logits_u_slogits_u_w

logits = model(inputs)
logits = de_interleave(logits, 2*args.mu+1)
logits_x = logits[:batch_size]
logits_u_w, logits_u_s = logits[batch_size:].chunk(2)

对有标记样本做 C r o s s E n t r o p y L o s s CrossEntropyLoss CrossEntropyLoss

 Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')

通过weak_augment样本计算伪标记pseudo labelmask,其中,mask用来筛选哪些样本最大预测概率超过阈值,可以拿来使用,哪些不能使用

pseudo_label = torch.softmax(logits_u_w.detach()/args.T, dim=-1)
max_probs, targets_u = torch.max(pseudo_label, dim=-1)
mask = max_probs.ge(args.threshold).float()

计算无标记样本的损失函数 L u L_u Lu,其中通过mask进行样本筛选

Lu = (F.cross_entropy(logits_u_s, targets_u, reduction='none') * mask).mean()

完整损失函数如下

loss = Lx + args.lambda_u * Lu

反向梯度更新,完成!~

  • 24
    点赞
  • 99
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值