Exponential Moving Average Normalization (EMAN): Paper Notes and PyTorch Implementation

I. The Paper

Unlike standard BN, whose statistics are computed within each batch, EMAN updates its BN statistics as an exponential moving average of the student model's BN statistics.

In the usual teacher network, the parameters are an exponential moving average of the student model's parameters, but the BN statistics for a batch are collected on the fly at the current iteration. This can cause a mismatch between the model parameters and the BN statistics in parameter space.
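The core idea can be sketched numerically (a minimal toy, not the repo's code): under EMAN the teacher's BN running statistics are updated with the same EMA rule as the weights, instead of being recomputed from the current batch.

```python
def ema_update(teacher, student, m=0.999):
    """Exponential moving average: teacher <- m * teacher + (1 - m) * student."""
    return m * teacher + (1.0 - m) * student

# Toy scalar standing in for one BN running statistic. With EMAN, the
# teacher's statistic drifts smoothly towards the student's, just like
# the weights, rather than jumping to each batch's fresh estimate.
teacher_stat, student_stat = 0.0, 1.0
for _ in range(1000):
    teacher_stat = ema_update(teacher_stat, student_stat, m=0.99)
```

After many steps `teacher_stat` converges towards the student's value while staying smooth across iterations, which is exactly what keeps the teacher's parameters and its BN statistics consistent with each other.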

II. The Code

I only studied the MoCo part of the code, so this post mainly covers where the EMAN source differs from vanilla MoCo. For a walkthrough of the MoCo source itself, see the blog post: Moco代码精读 (an annotated reading of the MoCo code).

1. moco.py (the model)

MoCoEMAN inherits from the MoCo model and overrides the _momentum_update_key_encoder method.

The difference from the original model is that it iterates over the full state_dict (parameters and BN buffers), and entries whose key contains "num_batches_tracked" are copied directly instead of momentum-updated. num_batches_tracked is a BatchNorm buffer that counts how many training batches the layer has seen (PyTorch uses it to compute a cumulative moving average when momentum=None); since it is an integer counter, an exponential moving average of it would not be meaningful.

class MoCoEMAN(MoCo):

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999, 
                 T=0.07, num_mlp=2, norm_layer=None):
        super(MoCoEMAN, self).__init__(base_encoder, dim, 
                                       K, m, T, num_mlp, norm_layer)
        self.do_shuffle_bn = False

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder's state_dict. 
        In MoCo, it is parameters
        """
        state_dict_q = self.encoder_q.state_dict()
        state_dict_k = self.encoder_k.state_dict()
        for (k_q, v_q), (k_k, v_k) in zip(state_dict_q.items(), 
                                          state_dict_k.items()):
            assert k_k == k_q, "state_dict names are different!"
            if 'num_batches_tracked' in k_k:
                v_k.copy_(v_q)
            else:
                v_k.copy_(v_k * self.m + (1. - self.m) * v_q)
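The num_batches_tracked buffer can be inspected directly on a standard BatchNorm layer (a small standalone demo, independent of the EMAN repo):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)      # a BN layer is in training mode by default
x = torch.randn(8, 4)
bn(x)                       # first forward pass in training mode
bn(x)                       # second forward pass

sd = bn.state_dict()
# state_dict holds both learnable parameters (weight, bias) and buffers
# (running_mean, running_var, num_batches_tracked). num_batches_tracked
# is an integer counter of training-mode forward passes, so EMAN copies
# it from the student instead of momentum-updating it.
print(sorted(sd.keys()))
print(sd["num_batches_tracked"])   # tensor(2) after two forward passes
```

This also shows why the EMAN update loops over `state_dict()` rather than `parameters()`: `parameters()` would miss the running statistics entirely.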

2. main_moco.py

2.1 A new warmup_epoch argument: for the first few epochs, the learning rate is gradually ramped up from a small value.

# in main_worker()
if epoch >= args.warmup_epoch:
    lr_schedule.adjust_learning_rate_with_min(optimizer, epoch, args)

# in train()
# warmup learning rate
if epoch < args.warmup_epoch:
    warmup_step = args.warmup_epoch * len(train_loader)
    curr_step = epoch * len(train_loader) + i + 1
    lr_schedule.warmup_learning_rate(optimizer, curr_step, warmup_step, args)
curr_lr.update(optimizer.param_groups[0]['lr'])

The warmup_learning_rate function is shown below. The learning rate grows linearly from a small value up to the configured lr: args.lr itself never changes, while scalar increases from 0 towards 1 as the step count grows.

def warmup_learning_rate(optimizer, curr_step, warmup_step, args):
    """linearly warm up learning rate"""
    lr = args.lr
    scalar = float(curr_step) / float(max(1, warmup_step))
    scalar = min(1., max(0., scalar))
    lr *= scalar
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
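To see the schedule concretely, the function can be driven with dummy stand-ins for the real optimizer and args objects (SimpleNamespace here is just a mock for illustration):

```python
from types import SimpleNamespace

def warmup_learning_rate(optimizer, curr_step, warmup_step, args):
    """Linearly warm up the learning rate from 0 to args.lr."""
    scalar = min(1.0, max(0.0, float(curr_step) / float(max(1, warmup_step))))
    for param_group in optimizer.param_groups:
        param_group["lr"] = args.lr * scalar

# Hypothetical stand-ins for the real torch optimizer / argparse namespace.
optimizer = SimpleNamespace(param_groups=[{"lr": 0.0}])
args = SimpleNamespace(lr=0.03)

warmup_step = 100
lrs = []
for curr_step in (1, 50, 100):
    warmup_learning_rate(optimizer, curr_step, warmup_step, args)
    lrs.append(optimizer.param_groups[0]["lr"])
# lrs ramps linearly: small at step 1, half of args.lr at step 50,
# exactly args.lr once curr_step reaches warmup_step.
```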

2.2 New KNN evaluation: a quick sanity check of the training result using nearest neighbours.

A subset of the training set and a subset of the validation set serve as the base (key) and query datasets respectively. The current model maps both to feature vectors; then, for each query image, a k-nearest-neighbour (KNN) search against the base features predicts its label, and the accuracy over the query set is reported.

# excerpt from main_moco.py; it assumes the file's imports, roughly:
# import time
# from datetime import timedelta
# import numpy as np
# import faiss
# import torch
# import torch.nn as nn
val_loader_base = torch.utils.data.DataLoader(
    datasets.ImageFolderWithPercent(
        traindir,
        data_transforms.get_transforms("DefaultVal"),
        percent=args.nn_mem_percent
    ),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)

val_loader_query = torch.utils.data.DataLoader(
    datasets.ImageFolderWithPercent(
        valdir,
        data_transforms.get_transforms("DefaultVal"),
        percent=args.nn_query_percent
    ),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)
def ss_validate(val_loader_base, val_loader_query, model, args):
    print("start KNN evaluation with key size={} and query size={}".format(
        len(val_loader_base.dataset.targets), len(val_loader_query.dataset.targets)))
    batch_time_key = utils.AverageMeter('Time', ':6.3f')
    batch_time_query = utils.AverageMeter('Time', ':6.3f')
    # switch to evaluate mode
    model.eval()

    feats_base = []
    target_base = []
    feats_query = []
    target_query = []

    with torch.no_grad():
        start = time.time()
        end = time.time()
        # Memory features
        for i, (images, target) in enumerate(val_loader_base):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute features
            feats = model(images)
            # L2 normalization
            feats = nn.functional.normalize(feats, dim=1)

            feats_base.append(feats)
            target_base.append(target)

            # measure elapsed time
            batch_time_key.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Extracting key features: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    i, len(val_loader_base), batch_time=batch_time_key))

        end = time.time()
        for i, (images, target) in enumerate(val_loader_query):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute features
            feats = model(images)
            # L2 normalization
            feats = nn.functional.normalize(feats, dim=1)

            feats_query.append(feats)
            target_query.append(target)

            # measure elapsed time
            batch_time_query.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Extracting query features: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                    i, len(val_loader_query), batch_time=batch_time_query))

        feats_base = torch.cat(feats_base, dim=0)
        target_base = torch.cat(target_base, dim=0)
        feats_query = torch.cat(feats_query, dim=0)
        target_query = torch.cat(target_query, dim=0)
        feats_base = feats_base.detach().cpu().numpy()
        target_base = target_base.detach().cpu().numpy()
        feats_query = feats_query.detach().cpu().numpy()
        target_query = target_query.detach().cpu().numpy()
        feat_time = time.time() - start

        # KNN search
        index = faiss.IndexFlatL2(feats_base.shape[1])
        index.add(feats_base)
        D, I = index.search(feats_query, args.num_nn)
        preds = np.array([np.bincount(target_base[n]).argmax() for n in I])

        NN_acc = (preds == target_query).sum() / len(target_query) * 100.0
        knn_time = time.time() - start - feat_time
        print("finished KNN evaluation, feature time: {}, knn time: {}".format(
            timedelta(seconds=feat_time), timedelta(seconds=knn_time)))
        print(' * NN Acc@1 {:.3f}'.format(NN_acc))

    return NN_acc
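The prediction line `np.bincount(target_base[n]).argmax()` is a majority vote over the labels of each query's `args.num_nn` nearest keys. The same logic can be sketched without faiss, using brute-force L2 distances on a toy dataset (illustration only, not the repo's code):

```python
import numpy as np

def knn_predict(feats_base, target_base, feats_query, k=5):
    """Majority-vote k-NN with brute-force squared-L2 distances."""
    # distance matrix of shape (n_query, n_base)
    d2 = ((feats_query[:, None, :] - feats_base[None, :, :]) ** 2).sum(-1)
    nn_idx = np.argsort(d2, axis=1)[:, :k]     # indices of the k nearest keys
    # bincount over the neighbours' integer labels, argmax = majority vote
    return np.array([np.bincount(target_base[n]).argmax() for n in nn_idx])

# toy example: two well-separated clusters
rng = np.random.default_rng(0)
base = np.concatenate([rng.normal(0.0, 0.1, (20, 2)),
                       rng.normal(5.0, 0.1, (20, 2))])
labels = np.array([0] * 20 + [1] * 20)
query = np.array([[0.0, 0.0], [5.0, 5.0]])
preds = knn_predict(base, labels, query, k=5)   # expected: [0, 1]
```

faiss's `IndexFlatL2` does the same exhaustive L2 search, just much faster; and since the features are L2-normalized beforehand, ranking by L2 distance is equivalent to ranking by cosine similarity.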
