f_k.params = f_q.params # 初始化
for x in loader: # 输入一个图像序列x,包含N张图,没有标签
x_q = aug(x) # 用于查询的图(数据增强得到)
x_k = aug(x) # 模板图(数据增强得到),自监督就体现在这里,只有图x和x的数据增强才被归为一类
q = f_q.forward(x_q) # 提取查询特征,输出NxC
k = f_k.forward(x_k) # 提取模板特征,输出NxC
# 不使用梯度更新f_k的参数,这是因为文章假设用于提取模板的表示应该是稳定的,不应立即更新
k = k.detach()
# 这里bmm是分批矩阵乘法
l_pos = bmm(q.view(N,1,C), k.view(N,C,1)) # 输出Nx1,也就是自己与自己的增强图的特征的匹配度
l_neg = mm(q.view(N,C), queue.view(C,K)) # 输出Nxk,自己与上一批次所有图的匹配度(全不匹配)
logits = cat([l_pos, l_neg], dim=1) # 输出Nx(1+k)
labels = zeros(N)
# NCE损失函数,就是为了保证自己与自己衍生的匹配度输出越大越好,否则越小越好
loss = CrossEntropyLoss(logits/t, labels)
loss.backward()
update(f_q.params) # f_q使用梯度立即更新
# 由于假设模板特征的表示方法是稳定的,因此它更新得更慢,这里使用动量法更新,相当于做了个滤波。
f_k.params = m*f_k.params+(1-m)*f_q.params
enqueue(queue, k) # 为了生成反例,所以引入了队列
dequeue(queue)
Moco伪代码
最新推荐文章于 2024-08-01 22:10:38 发布