一、论文
与标准 BN 不同的是,标准 BN 的统计数据是在每个batch内计算的,而 EMAN 是通过student模型的 BN 统计数据指数移动平均来更新其统计数据的。
在teacher网络中,其参数取自student模型参数的指数移动平均,但一个batch中的 BN 统计数据是在当前迭代中即时收集的,这可能会导致模型参数与参数空间中的 BN 统计数据不匹配。
二、代码
作者只对Moco部分的代码进行了研究,因此主要讲解EMAN源码中和原本Moco不一样的地方,对于Moco的源码讲解可以查看博客:Moco代码精读
1. moco.py文件(模型)
MoCoEMAN 继承自Moco模型,并重写了_momentum_update_key_encoder方法
和原模型的区别在于,BN层的计数参数(当参数键值包含"num_batches_tracked"时)不会进行动量更新,而是直接从student拷贝。"num_batches_tracked"是 BN 层的一个整数缓冲区,记录该层已经处理过的batch数量,在 momentum=None 时用于计算累积滑动平均;它是计数器而非统计量,对其做指数移动平均没有意义,所以代码中直接复制。
class MoCoEMAN(MoCo):
    """MoCo variant using EMAN (Exponential Moving Average Normalization).

    Differs from plain MoCo in two ways: the momentum update is applied to
    the key encoder's full state_dict (parameters AND BN buffers) rather
    than parameters only, and key-side batch shuffling is disabled.
    """

    def __init__(self, base_encoder, dim=128, K=65536, m=0.999,
                 T=0.07, num_mlp=2, norm_layer=None):
        super(MoCoEMAN, self).__init__(base_encoder, dim,
                                       K, m, T, num_mlp, norm_layer)
        # EMAN keeps the teacher's BN statistics aligned with its averaged
        # weights, so shuffling BN across GPUs is no longer needed.
        self.do_shuffle_bn = False

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder's state_dict.
        In MoCo, it is parameters
        """
        query_state = self.encoder_q.state_dict()
        key_state = self.encoder_k.state_dict()
        paired = zip(query_state.items(), key_state.items())
        for (name_q, tensor_q), (name_k, tensor_k) in paired:
            assert name_k == name_q, "state_dict names are different!"
            if 'num_batches_tracked' in name_k:
                # Integer batch counter, not a statistic: copy it verbatim,
                # an exponential moving average would be meaningless.
                tensor_k.copy_(tensor_q)
            else:
                # EMA over both weights and BN running statistics.
                tensor_k.copy_(tensor_k * self.m + (1. - self.m) * tensor_q)
2. main_moco.py文件
2.1 参数新增warmup_epoch,也就是在刚开始的若干个epoch学习率从很小的值逐渐增大
# in main_worker(): once warmup is over, fall back to the regular
# schedule (with a minimum learning rate) once per epoch
if epoch >= args.warmup_epoch:
    lr_schedule.adjust_learning_rate_with_min(optimizer, epoch, args)
# in train(): during the first args.warmup_epoch epochs, linearly
# warm up the learning rate once per iteration
if epoch < args.warmup_epoch:
    warmup_step = args.warmup_epoch * len(train_loader)  # total warmup iterations
    curr_step = epoch * len(train_loader) + i + 1        # 1-based global iteration
    lr_schedule.warmup_learning_rate(optimizer, curr_step, warmup_step, args)
curr_lr.update(optimizer.param_groups[0]['lr'])
warmup_learning_rate方法如下,也就是学习率会从一个比较小的值逐渐增大到设定的lr:这里args.lr一直不变,而scalar会随着迭代步数curr_step的增加(每个iteration更新一次,而不是每个epoch)逐渐增大到1。
def warmup_learning_rate(optimizer, curr_step, warmup_step, args):
    """Linearly ramp the learning rate from 0 up to ``args.lr``.

    ``curr_step`` is the current global iteration, ``warmup_step`` the total
    number of warmup iterations. The scale factor is clamped to [0, 1], so
    once warmup is finished the rate simply stays at ``args.lr``.
    """
    # fraction of warmup completed; guard against warmup_step == 0
    denom = max(1, warmup_step)
    fraction = min(1., max(0., float(curr_step) / float(denom)))
    scaled_lr = args.lr * fraction
    for group in optimizer.param_groups:
        group['lr'] = scaled_lr
2.2 新增 KNN evaluation,使用最近邻方法对训练结果初步进行评估
从训练集和验证集中分别取出一部分数据作为base和query数据集,使用当前训练好的模型得到其映射特征,再使用最近邻(KNN)的方法使用base数据集去预测query数据集中图片的标签,并且统计正确率。
# KNN-evaluation loaders: both use the deterministic validation transform
# and shuffle=False, so extracted features stay aligned with dataset.targets.
# "base" loader = memory bank, drawn from a subset of the *training* set.
val_loader_base = torch.utils.data.DataLoader(
    datasets.ImageFolderWithPercent(
        traindir,
        data_transforms.get_transforms("DefaultVal"),
        percent=args.nn_mem_percent  # fraction of train images kept
    ),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)
# "query" loader = images to classify, a subset of the *validation* set.
val_loader_query = torch.utils.data.DataLoader(
    datasets.ImageFolderWithPercent(
        valdir,
        data_transforms.get_transforms("DefaultVal"),
        percent=args.nn_query_percent  # fraction of val images kept
    ),
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)
def ss_validate(val_loader_base, val_loader_query, model, args):
    """Nearest-neighbour (KNN) evaluation of the encoder.

    Embeds the base ("memory") split and the query split with ``model``,
    L2-normalizes the features, then labels each query image by a majority
    vote over the labels of its ``args.num_nn`` nearest base features
    (exact L2 search via faiss).

    Args:
        val_loader_base: loader over the memory split (train subset).
        val_loader_query: loader over the query split (val subset).
        model: feature extractor; assumed to map an image batch to one
            embedding row per image — TODO confirm against the caller.
        args: uses ``gpu``, ``print_freq`` and ``num_nn``.

    Returns:
        NN top-1 accuracy on the query split, in percent (float).
    """
    print("start KNN evaluation with key size={} and query size={}".format(
        len(val_loader_base.dataset.targets), len(val_loader_query.dataset.targets)))
    batch_time_key = utils.AverageMeter('Time', ':6.3f')
    batch_time_query = utils.AverageMeter('Time', ':6.3f')
    # switch to evaluate mode
    model.eval()
    feats_base = []    # per-batch base features
    target_base = []   # per-batch base labels
    feats_query = []   # per-batch query features
    target_query = []  # per-batch query labels
    with torch.no_grad():
        start = time.time()
        end = time.time()
        # Memory features
        for i, (images, target) in enumerate(val_loader_base):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            # compute features
            feats = model(images)
            # L2 normalization: makes L2 distance monotone in cosine similarity
            feats = nn.functional.normalize(feats, dim=1)
            feats_base.append(feats)
            target_base.append(target)
            # measure elapsed time
            batch_time_key.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                print('Extracting key features: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                          i, len(val_loader_base), batch_time=batch_time_key))
        end = time.time()
        # Query features (same pipeline as above)
        for i, (images, target) in enumerate(val_loader_query):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            # compute features
            feats = model(images)
            # L2 normalization
            feats = nn.functional.normalize(feats, dim=1)
            feats_query.append(feats)
            target_query.append(target)
            # measure elapsed time
            batch_time_query.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                print('Extracting query features: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                          i, len(val_loader_query), batch_time=batch_time_query))
    # concatenate per-batch results and move to CPU numpy for faiss
    feats_base = torch.cat(feats_base, dim=0)
    target_base = torch.cat(target_base, dim=0)
    feats_query = torch.cat(feats_query, dim=0)
    target_query = torch.cat(target_query, dim=0)
    feats_base = feats_base.detach().cpu().numpy()
    target_base = target_base.detach().cpu().numpy()
    feats_query = feats_query.detach().cpu().numpy()
    target_query = target_query.detach().cpu().numpy()
    feat_time = time.time() - start
    # KNN search: exact L2 index over all base features
    index = faiss.IndexFlatL2(feats_base.shape[1])
    index.add(feats_base)
    # I[q] holds the indices of the num_nn nearest base rows for query q
    D, I = index.search(feats_query, args.num_nn)
    # majority vote over neighbour labels (ties resolved to the smallest label id)
    preds = np.array([np.bincount(target_base[n]).argmax() for n in I])
    NN_acc = (preds == target_query).sum() / len(target_query) * 100.0
    knn_time = time.time() - start - feat_time
    print("finished KNN evaluation, feature time: {}, knn time: {}".format(
        timedelta(seconds=feat_time), timedelta(seconds=knn_time)))
    print(' * NN Acc@1 {:.3f}'.format(NN_acc))
    return NN_acc