'''
Created on Feb 7, 2021
@author: nakaizura
'''
#MoCo的主要就是以下操作:
#1 维护queue来动态更新。
#2 keys部分单独momentum以解耦batch size。
#3 一个trick:Shuffling BN
import torch
import torch.nn as nn
class MoCo(nn.Module):
"""
主要就是这个class来完成的逻辑
"""
def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
"""
dim: 特征维度 (default: 128)
K: 负例序列长度 (default: 65536)
m: k部分的moco momentum更新率 (default: 0.999)
T: softmax平滑系数 (default: 0.07)
"""
super(MoCo, self).__init__()
self.K = K
self.m = m
self.T = T
# 创建解码器
# num_classes输出的类别数量
self.encoder_q = base_encoder(num_classes