NT-Xent Loss code notes (PyTorch)

Code notes on NT-Xent Loss, the contrastive-learning loss function.
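
For reference, both implementations below compute the NT-Xent (normalized temperature-scaled cross entropy) loss from SimCLR. For a positive pair $(i, j)$ among the $2N$ augmented examples in a batch,

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)},$$

where $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \, \lVert v \rVert)$ is the cosine similarity and $\tau$ is the temperature.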

Implementation 1

import torch
import torch.nn as nn

class MemoryBankModule(torch.nn.Module):
    """内存库实现 这是由轻Python包实现的所有损失函数的父类。这样,如果需要,任何丢失都可以与记忆库一起使用。
    size:内存库可以存储的键数。如果设置为0,表示不使用内存库。
    """

    def __init__(self, size: int = 2 ** 16):

        super(MemoryBankModule, self).__init__()

        if size < 0:
            msg = f'Illegal memory bank size {size}, must be non-negative.'
            raise ValueError(msg)

        self.size = size
        self.bank = None
        self.bank_ptr = None

    @torch.no_grad()
    def _init_memory_bank(self, dim: int):
        """如果内存库为空,则初始化内存库
        dim:其中的维数存储在库中
        """
        # 创建内存库
        # 我们可以像在moco repo https://github.com/facebookresearch/moco中那样使用寄存器缓冲区,但我们不想污染我们的检查点(checkpoints)
        self.bank = torch.randn(dim, self.size)
        self.bank = torch.nn.functional.normalize(self.bank, dim=0)
        self.bank_ptr = torch.LongTensor([0])

    @torch.no_grad()
    def _dequeue_and_enqueue(self, batch: torch.Tensor):
        """Remove the oldest batch and add the latest one (dequeue and enqueue).

        batch: The latest batch of keys to add to the memory bank.
        """
        batch_size = batch.shape[0]
        ptr = int(self.bank_ptr)

        if ptr + batch_size >= self.size:
            # The batch does not fit: fill the remaining slots with the
            # first entries of the batch and reset the pointer.
            self.bank[:, ptr:] = batch[:self.size - ptr].T.detach()
            self.bank_ptr[0] = 0
        else:
            self.bank[:, ptr:ptr + batch_size] = batch.T.detach()
            self.bank_ptr[0] = ptr + batch_size

    def forward(self,
                output: torch.Tensor,
                labels: torch.Tensor = None,
                update: bool = False):
        """
        查询内存库是否有 额外的 新增的? 负样本  Query memory bank for additional negative samples
        output:
            The output of the model.
        labels:
            Should always be None, will be ignored.

        Returns:The output if the memory bank is of size 0, otherwise the output and the entries from the memory bank.
        如果内存库的大小为0,则输出,否则输出和内存库中的条目。
        """

        # No memory bank: return the output as-is.
        if self.size == 0:
            return output, None

        _, dim = output.shape

        # Initialize the memory bank if this has not been done yet.
        if self.bank is None:
            self._init_memory_bank(dim)

        # Query and update the memory bank.
        bank = self.bank.clone().detach()

        # Only update the memory bank if we will do a backward pass (gradient) later.
        if update:
            self._dequeue_and_enqueue(output)

        return output, bank
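
To make the ring-buffer behavior concrete, here is a small usage sketch (the sizes and names are illustrative, not from the original post): the pointer advances by one batch per update and wraps back to zero once the bank is full, overwriting the oldest keys.

# Illustrative usage sketch of the memory bank (hypothetical sizes).
bank_module = MemoryBankModule(size=8)

for step in range(5):
    batch = torch.randn(4, 128)          # (batch_size, dim)
    out, negatives = bank_module(batch, update=True)
    # negatives has shape (dim, size) = (128, 8); bank_ptr alternates
    # 4, 0, 4, 0, 4 as the 4-sample batches wrap around the 8-slot bank.
    print(int(bank_module.bank_ptr))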


class NTXentLoss(MemoryBankModule):

    def __init__(self,
                 temperature: float = 0.5,           # scaling factor for the logits; as in MoCo, the temperature controls how sharply the softmax separates positives from negatives
                 memory_bank_size: int = 0,          # size of the memory bank used to store negative sample vectors
                 gather_distributed: bool = False):  # whether to gather features across distributed processes
        super(NTXentLoss, self).__init__(size=memory_bank_size)
        self.temperature = temperature
        self.gather_distributed = gather_distributed
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        self.eps = 1e-8

        if abs(self.temperature) < self.eps:
            raise ValueError('Illegal temperature: abs({}) < 1e-8'
                             .format(self.temperature))

    def forward(self,
                out0: torch.Tensor,
                out1: torch.Tensor):
        """1. Inputs: out0 / out1 are the output feature vectors of the
        first / second view, each of shape (batch_size, embedding_size)."""
        device = out0.device
        batch_size, _ = out0.shape

        # 2. Normalize the outputs to unit length along dim=1, so every
        #    sample becomes a unit vector and the dot products below are
        #    cosine similarities.
        out0 = nn.functional.normalize(out0, dim=1)
        out1 = nn.functional.normalize(out1, dim=1)

        # 3. Ask the memory bank for negative samples, and enqueue out1
        #    only if out0 requires gradients; otherwise leave the bank
        #    unchanged (this keeps the memory bank constant, e.g. when
        #    evaluating the loss on a test set).
        #    out1:      shape (batch_size, embedding_size)
        #    negatives: shape (embedding_size, memory_bank_size)
        out1, negatives = super(NTXentLoss, self).forward(out1, update=out0.requires_grad)

        # 4. Compute similarities. We use cosine similarity, which here is
        #    just a dot product (einsum) because all vectors have already
        #    been normalized to unit length. In einsum notation:
        #    n = batch_size, c = embedding_size, k = memory_bank_size.

        if negatives is not None:
            # Use the negatives from the memory bank.
            negatives = negatives.to(device)

            # Positive logits l_pos of shape (batch_size, 1): sim_pos[i] is
            # the similarity of the i-th sample in the batch to its positive
            # pair (out0 acts as the query q and out1 as the key k, as in MoCo).
            sim_pos = torch.einsum('nc,nc->n', out0, out1).unsqueeze(-1)

            # Negative logits l_neg of shape (batch_size, memory_bank_size):
            # sim_neg[i, j] is the similarity of the i-th sample to the
            # j-th negative sample.
            sim_neg = torch.einsum('nc,ck->nk', out0, negatives)

            # Concatenate l_pos and l_neg and divide by the temperature
            # (which controls the concentration level of the distribution);
            # logits has shape (batch_size, 1 + memory_bank_size). The labels
            # all point at the first "class", i.e. sim_pos, so that it is
            # maximized relative to sim_neg.
            logits = torch.cat([sim_pos, sim_neg], dim=1) / self.temperature
            labels = torch.zeros(logits.shape[0], device=device, dtype=torch.long)


        else:
            # 5. No memory bank: use the other samples in the batch as
            #    negatives and build the corresponding labels for the cross
            #    entropy. Single-process case, so the "large" tensors are
            #    simply the local ones.
            out0_large = out0
            out1_large = out1
            diag_mask = torch.eye(batch_size, device=out0.device, dtype=torch.bool)

            # Compute pairwise similarities; here n = batch_size and
            # m = batch_size * world_size, so the resulting tensors have
            # shape (n, m).
            logits_00 = torch.einsum('nc,mc->nm', out0, out0_large) / self.temperature
            logits_01 = torch.einsum('nc,mc->nm', out0, out1_large) / self.temperature
            logits_10 = torch.einsum('nc,mc->nm', out1, out0_large) / self.temperature
            logits_11 = torch.einsum('nc,mc->nm', out1, out1_large) / self.temperature

            # Remove the similarities between the same views of the same image.
            logits_00 = logits_00[~diag_mask].view(batch_size, -1)
            logits_11 = logits_11[~diag_mask].view(batch_size, -1)

            # Concatenate the logits; the final logits tensor has shape
            # (2*n, 2*m - 1).
            logits_0100 = torch.cat([logits_01, logits_00], dim=1)
            logits_1011 = torch.cat([logits_10, logits_11], dim=1)
            logits = torch.cat([logits_0100, logits_1011], dim=0)

            # Create the labels: the positive for row i sits at column i.
            labels = torch.arange(batch_size, device=device, dtype=torch.long)
            labels = labels.repeat(2)

        # 6. Cross-entropy loss between the logits and the labels.
        loss = self.cross_entropy(logits, labels)

        return loss
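
A quick smoke test of Implementation 1 (shapes are illustrative, not from the original post): with a non-zero memory bank, the first forward pass initializes the bank with random unit vectors, so the loss is well defined from the start.

# Illustrative usage (hypothetical shapes).
criterion = NTXentLoss(temperature=0.5, memory_bank_size=4096)

out0 = torch.randn(32, 128, requires_grad=True)  # embeddings of view 1
out1 = torch.randn(32, 128, requires_grad=True)  # embeddings of view 2

loss = criterion(out0, out1)  # scalar; logits shape (32, 1 + 4096)
loss.backward()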

Implementation 2

import torch
import torch.nn as nn
import torch.distributed as dist


class NT_Xent(nn.Module):
    def __init__(self, temperature=0.07):
        super(NT_Xent, self).__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size, world_size):
        # Mask out the self-similarities (main diagonal) and the positive
        # pairs, which sit on the diagonals at offset
        # +/- batch_size * world_size.
        N = 2 * batch_size * world_size
        mask = torch.ones((N, N), dtype=torch.bool)
        mask = mask.fill_diagonal_(0)
        for i in range(batch_size * world_size):
            mask[i, batch_size * world_size + i] = 0
            mask[batch_size * world_size + i, i] = 0
        return mask

    def forward(self, zz):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
        """
        z_i, z_j = zz[:, 0], zz[:, 1]  # zz has shape (batch_size, 2, embedding_size)
        batch_size = z_i.shape[0]
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        N = 2 * batch_size * world_size

        mask = self.mask_correlated_samples(batch_size, world_size)

        z = torch.cat((z_i, z_j), dim=0)
        if world_size > 1:
            # GatherLayer is an autograd-aware all_gather that collects z
            # from all processes (see the sketch after this class).
            z = torch.cat(GatherLayer.apply(z), dim=0)

        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        # The positive pairs sit on the diagonals of the (N, N) similarity
        # matrix at offset +/- batch_size * world_size; with distributed
        # training every GPU still computes the full N x N matrix.
        sim_i_j = torch.diag(sim, batch_size * world_size)
        sim_j_i = torch.diag(sim, -batch_size * world_size)

        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[mask].reshape(N, -1)
        labels = torch.zeros(N, device=positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss
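
Implementation 2 depends on a GatherLayer that is not defined in the post: an autograd-aware all_gather used in several SimCLR codebases so that gradients still flow back to the local shard after gathering. Below is a minimal sketch of such a layer, assuming torch.distributed has been initialized, followed by a single-process usage example.

# Sketch of the assumed GatherLayer (common in SimCLR codebases);
# not defined in the original post.
class GatherLayer(torch.autograd.Function):
    """All-gather tensors from every process while supporting backprop."""

    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the gradients from all processes and return the local shard.
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]


# Single-process usage sketch (world_size == 1, so GatherLayer is unused).
criterion = NT_Xent(temperature=0.07)
zz = torch.randn(32, 2, 128)  # (batch_size, 2 views, embedding_size)
loss = criterion(zz)          # logits shape (64, 63), labels all zero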