Code notes on the NT-Xent loss for contrastive learning
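Both versions below implement the NT-Xent (normalized temperature-scaled cross entropy) loss from SimCLR. For reference, for a positive pair (i, j) among the 2N augmented examples of a batch, the loss is

\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}

where \mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \, \lVert v \rVert) is the cosine similarity and \tau is the temperature.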
Version 1
import torch
import torch.nn as nn


class MemoryBankModule(torch.nn.Module):
    """Memory bank implementation.

    This is the parent class for the loss functions implemented by the
    lightly Python package, so that any loss can optionally be used
    together with a memory bank.

    size: Number of keys the memory bank can store. A size of 0 means
        the memory bank is not used.
    """
    def __init__(self, size: int = 2 ** 16):
        super(MemoryBankModule, self).__init__()
        if size < 0:
            msg = f'Illegal memory bank size {size}, must be non-negative.'
            raise ValueError(msg)
        self.size = size
        self.bank = None
        self.bank_ptr = None
    @torch.no_grad()
    def _init_memory_bank(self, dim: int):
        """Initialize the memory bank if it is empty.

        dim: Dimensionality of the vectors stored in the bank.
        """
        # Create the memory bank. We could use register_buffer as in the
        # MoCo repo https://github.com/facebookresearch/moco, but we do
        # not want to pollute our checkpoints.
        self.bank = torch.randn(dim, self.size)
        self.bank = torch.nn.functional.normalize(self.bank, dim=0)
        self.bank_ptr = torch.LongTensor([0])
    @torch.no_grad()
    def _dequeue_and_enqueue(self, batch: torch.Tensor):
        """Remove the oldest batch from the queue and add the latest one.

        batch: The latest batch of keys to add to the memory bank.
        """
        batch_size = batch.shape[0]
        ptr = int(self.bank_ptr)
        # If the batch would overflow the bank, fill the remaining slots
        # with the first part of the batch and reset the pointer.
        if ptr + batch_size >= self.size:
            self.bank[:, ptr:] = batch[:self.size - ptr].T.detach()
            self.bank_ptr[0] = 0
        else:
            self.bank[:, ptr:ptr + batch_size] = batch.T.detach()
            self.bank_ptr[0] = ptr + batch_size
    def forward(self,
                output: torch.Tensor,
                labels: torch.Tensor = None,
                update: bool = False):
        """Query the memory bank for additional negative samples.

        output: The output of the model.
        labels: Should always be None, will be ignored.

        Returns: The output alone if the memory bank has size 0,
        otherwise the output together with the entries of the bank.
        """
        # No memory bank: just return the output.
        if self.size == 0:
            return output, None
        _, dim = output.shape
        # Initialize the memory bank if that has not been done yet.
        if self.bank is None:
            self._init_memory_bank(dim)
        # Query the memory bank, then update it.
        bank = self.bank.clone().detach()
        # Only update the memory bank if a backward pass (gradient) will
        # follow later.
        if update:
            self._dequeue_and_enqueue(output)
        return output, bank
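# A quick, hypothetical sanity check of the bank on its own (the size and
# dimensionality below are arbitrary): query it once and inspect the shapes.
def _demo_memory_bank():
    bank_module = MemoryBankModule(size=128)
    keys = torch.nn.functional.normalize(torch.randn(8, 32), dim=1)
    out, bank = bank_module(keys, update=True)
    print(out.shape, bank.shape)  # torch.Size([8, 32]) torch.Size([32, 128])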
class NTXentLoss(MemoryBankModule):
    def __init__(self,
                 temperature: float = 0.5,
                 memory_bank_size: int = 0,
                 gather_distributed: bool = False):
        # temperature: scaling factor for the logits; a lower temperature
        #   sharpens the softmax, i.e. it controls how strongly the model
        #   discriminates against negative samples (as in MoCo).
        # memory_bank_size: size of the memory bank of negative
        #   embeddings; 0 disables the bank.
        # gather_distributed: whether to gather embeddings across
        #   distributed processes (unused in this stripped-down version).
        super(NTXentLoss, self).__init__(size=memory_bank_size)
        self.temperature = temperature
        self.gather_distributed = gather_distributed
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        self.eps = 1e-8
        if abs(self.temperature) < self.eps:
            raise ValueError('Illegal temperature: abs({}) < 1e-8'
                             .format(self.temperature))
    # 1. Inputs: out0 / out1 are the embeddings of the first / second
    #    view, both of shape (batch_size, embedding_size).
    def forward(self,
                out0: torch.Tensor,
                out1: torch.Tensor):
        device = out0.device
        batch_size, _ = out0.shape

        # 2. Normalize the outputs to unit length along dim=1 so that the
        #    dot products below are cosine similarities.
        out0 = nn.functional.normalize(out0, dim=1)
        out1 = nn.functional.normalize(out1, dim=1)

        # 3. Ask the memory bank for negative samples and extend it with
        #    out1. Since update=out0.requires_grad, out1 is only enqueued
        #    when a backward pass will follow; otherwise the bank stays
        #    unchanged (e.g. when evaluating the loss on a test set).
        # out1:      shape (batch_size, embedding_size)
        # negatives: shape (embedding_size, memory_bank_size)
        out1, negatives = super(NTXentLoss, self).forward(
            out1, update=out0.requires_grad)
        # 4. Compute similarities. The cosine similarity reduces to a dot
        #    product (einsum) because all vectors are already normalized
        #    to unit length. In einsum notation: n = batch_size,
        #    c = embedding_size, k = memory_bank_size.
        if negatives is not None:
            # Use negatives from the memory bank; out0 plays the role of
            # the query q and out1 of the key k, as in MoCo.
            negatives = negatives.to(device)

            # Positive logits, shape (batch_size, 1): sim_pos[i] is the
            # similarity of the i-th sample to its positive pair.
            sim_pos = torch.einsum('nc,nc->n', out0, out1).unsqueeze(-1)

            # Negative logits, shape (batch_size, memory_bank_size):
            # sim_neg[i, j] is the similarity of the i-th sample to the
            # j-th negative sample.
            sim_neg = torch.einsum('nc,ck->nk', out0, negatives)

            # Concatenate sim_pos and sim_neg and divide by the
            # temperature (which controls the concentration level of the
            # distribution); logits have shape (N, 1 + K). The label is
            # the first "class", i.e. sim_pos, which is thereby maximized
            # relative to sim_neg.
            logits = torch.cat([sim_pos, sim_neg], dim=1) / self.temperature
            labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)
        else:
            # No memory bank: use the in-batch negatives. Compute the
            # cosine similarities between the two views and build the
            # logits and labels for the cross entropy from them.
            # (In the distributed version out0_large / out1_large would
            # be the gathered embeddings; here they equal out0 / out1.)
            out0_large = out0
            out1_large = out1
            diag_mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool)

            # Compute similarities; here n = m = batch_size (with
            # distributed gathering m would be batch_size * world_size).
            # The resulting matrices have shape (n, m).
            logits_00 = torch.einsum('nc,mc->nm', out0, out0_large) / self.temperature
            logits_01 = torch.einsum('nc,mc->nm', out0, out1_large) / self.temperature
            logits_10 = torch.einsum('nc,mc->nm', out1, out0_large) / self.temperature
            logits_11 = torch.einsum('nc,mc->nm', out1, out1_large) / self.temperature

            # Remove similarities between the same view of the same image.
            logits_00 = logits_00[~diag_mask].view(batch_size, -1)
            logits_11 = logits_11[~diag_mask].view(batch_size, -1)

            # Concatenate the logits; the final logits tensor has shape
            # (2*n, 2*m - 1).
            logits_0100 = torch.cat([logits_01, logits_00], dim=1)
            logits_1011 = torch.cat([logits_10, logits_11], dim=1)
            logits = torch.cat([logits_0100, logits_1011], dim=0)

            # Create labels: the positive for sample i sits at column i.
            labels = torch.arange(batch_size, device=device, dtype=torch.long)
            labels = labels.repeat(2)
        # 5. Cross entropy between the logits and the labels.
        loss = self.cross_entropy(logits, labels)
        return loss
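A minimal smoke test covering both code paths (batch and embedding sizes are arbitrary, and the snippet assumes the classes above are in scope):

out0 = torch.randn(4, 16, requires_grad=True)
out1 = torch.randn(4, 16)

criterion = NTXentLoss(temperature=0.5)                          # in-batch negatives
loss_in_batch = criterion(out0, out1)

criterion_mb = NTXentLoss(temperature=0.5, memory_bank_size=64)  # memory-bank negatives
loss_bank = criterion_mb(out0, out1)

print(loss_in_batch.item(), loss_bank.item())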
Version 2
import torch
import torch.nn as nn
import torch.distributed as dist


class NT_Xent(nn.Module):
    def __init__(self, temperature=0.07):
        super(NT_Xent, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size, world_size):
        # Mask out the diagonal (self-similarity) and the positive pairs,
        # leaving only the negative entries of the similarity matrix.
        N = 2 * batch_size * world_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask = mask.fill_diagonal_(0)
        for i in range(batch_size * world_size):
            # Positive pairs sit batch_size * world_size positions apart,
            # matching the diagonal offsets used in forward().
            mask[i, batch_size * world_size + i] = 0
            mask[batch_size * world_size + i, i] = 0
        return mask
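    # For example, with batch_size=2 and world_size=1 (so N = 4) the mask
    # is (True = negative entry):
    #     [[0, 1, 0, 1],
    #      [1, 0, 1, 0],
    #      [0, 1, 0, 1],
    #      [1, 0, 1, 0]]
    # The diagonal and the positive pairs (i, i + 2) are excluded, so each
    # row keeps N - 2 in-batch negatives.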
    def forward(self, zz):
        """
        We do not sample negative examples explicitly. Instead, given a
        positive pair, similar to (Chen et al., 2017), we treat the other
        2(N - 1) augmented examples within a minibatch as negatives.
        """
        z_i, z_j = zz[:, 0], zz[:, 1]
        batch_size = z_i.shape[0]
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        N = 2 * batch_size * world_size
        mask = self.mask_correlated_samples(batch_size, world_size)
        z = torch.cat((z_i, z_j), dim=0)
        if world_size > 1:
            z = torch.cat(GatherLayer.apply(z), dim=0)
        # Full pairwise cosine similarity matrix of shape (N, N).
        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
        sim_i_j = torch.diag(sim, batch_size * world_size)
        sim_j_i = torch.diag(sim, -batch_size * world_size)
        # We have 2N samples, but with distributed training every GPU also
        # gets N examples, resulting in 2 x N x N similarities.
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[mask].reshape(N, -1)
        # The positive logit is placed in column 0, so all labels are 0.
        labels = torch.zeros(N, device=positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss
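The forward pass above references a GatherLayer that is not defined in this snippet. With the dist.is_initialized() guard, single-process use never reaches it; for multi-GPU training, a minimal sketch consistent with the Spijkervet/SimCLR implementation (an all_gather that routes the gradient back to the local shard) looks like this:

import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backpropagation
    (plain dist.all_gather does not propagate gradients)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        (input,) = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        # Each process keeps only the gradient slice for its own shard.
        grad_out[:] = grads[dist.get_rank()]
        return grad_out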