Study notes on the CVPR 2022 face recognition paper Partial FC and its code

Paper link: https://openaccess.thecvf.com/content/CVPR2022/papers/An_Killing_Two_Birds_With_One_Stone_Efficient_and_Robust_Training_CVPR_2022_paper.pdf

Code link: insightface/recognition/arcface_torch at master · deepinsight/insightface · GitHub

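Background

Learning discriminative embeddings with margin-based softmax losses on million-scale datasets is the current SOTA approach to face recognition. However, the memory and compute cost of the fully connected (FC) layer grows linearly with the number of IDs in the training set. In addition, large-scale training data suffers from inter-class conflict (the same person split across different IDs) and from a long-tailed distribution.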
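To make the linear growth concrete, here is a rough back-of-the-envelope sketch. It assumes fp32 values (4 bytes each) and a 512-dimensional embedding; the helper name fc_memory_gib, the batch size, and the ID counts are illustrative choices, not figures from the paper.

```python
def fc_memory_gib(num_classes: int, embedding_size: int = 512,
                  batch_size: int = 0, bytes_per_value: int = 4) -> float:
    """Rough memory (GiB) for the FC weight matrix plus one batch of logits."""
    weight_values = num_classes * embedding_size  # one center per ID
    logit_values = batch_size * num_classes       # one logit per (sample, ID)
    return (weight_values + logit_values) * bytes_per_value / 2**30


if __name__ == "__main__":
    # Doubling the number of IDs doubles both terms.
    for n in (1_000_000, 2_000_000, 4_000_000):
        print(n, "IDs:", round(fc_memory_gib(n, batch_size=512), 2), "GiB")
```

Gradients and optimizer state add further copies of the weight matrix, so the real footprint is larger than this estimate.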

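Conventional FC

Applying a conventional FC layer to a large-scale dataset has the following drawbacks:

1. Gradient confusion under inter-class conflict

In WebFace42M, many pairs of different classes have a cosine similarity greater than 0.4, which shows that inter-class conflict still exists in such cleaned datasets. Optimizing directly on this data causes gradient confusion (the features of one person are very similar, yet they are forced apart into two IDs).

2. Centers of tail classes undergo too many passive updates

IDs with very few images are still updated in every iteration (passively, as negative classes), which can make the optimization counterproductive.

3. The storage and computation of the FC layer can easily exceed current GPU capabilities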
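The first drawback can be probed with a small sketch like the one below, which computes the cosine similarity between every pair of class centers and lists the pairs above a threshold. It uses plain Python so it runs without dependencies; the function names and the toy centers are illustrative assumptions, and 0.4 is the threshold quoted above.

```python
import math
from itertools import combinations


def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def conflicting_pairs(centers, threshold=0.4):
    """Return (i, j, similarity) for class pairs whose centers are suspiciously close."""
    pairs = []
    for i, j in combinations(range(len(centers)), 2):
        sim = cosine(centers[i], centers[j])
        if sim > threshold:
            pairs.append((i, j, sim))
    return pairs


if __name__ == "__main__":
    # Classes 0 and 1 point in almost the same direction: possibly the same person.
    toy_centers = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.0]]
    for i, j, sim in conflicting_pairs(toy_centers):
        print(f"classes {i} and {j}: cosine similarity {sim:.3f}")
```

The loop is quadratic in the number of classes, so on a dataset with millions of IDs this check would need batching or approximate nearest-neighbor search.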

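PartialFC

During training, all class centers are still maintained, but only a small random subset of negative class centers is sampled to compute the margin-based softmax loss, instead of using every negative class center in each iteration. More specifically, the embeddings and labels are first gathered from every GPU, and the combined features and labels are then distributed to all GPUs. To balance memory usage and compute cost across GPUs, each GPU is given a memory buffer (in the code below, the index picked from perm, of size self.num_sample = int(self.sample_rate * self.num_local)). The buffer size is determined by the total number of classes and the sampling rate of the negative class centers. On each GPU, the positive class centers are first selected by label and placed in the buffer, and then a small random subset of negative class centers (self.num_sample - positive.size(0) of them) fills the rest of the buffer: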

def sample(self, labels, index_positive):
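    """
    Sample the class centers used in this iteration.
    This function modifies `labels` in place.
    Parameters:
    -----------
    labels: torch.Tensor
        Local labels, already shifted into [0, num_local); -1 marks classes owned by another rank.
    index_positive: torch.Tensor
        Boolean mask marking the entries of `labels` whose class lives on this rank.
    """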
    with torch.no_grad():
        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            index = positive
        self.weight_index = index

        labels[index_positive] = torch.searchsorted(index, labels[index_positive])

    return self.weight[self.weight_index]
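The sampling logic above can be mimicked in plain Python to see what it produces. This is a minimal sketch rather than the repository code: random priorities stand in for torch.rand, sorting stands in for torch.topk, and bisect_left stands in for torch.searchsorted. The function name sample_centers and its seed argument are illustrative.

```python
import bisect
import random


def sample_centers(local_labels, num_local, num_sample, seed=None):
    """Pick the class-center indices for one iteration and remap labels onto them.

    local_labels are already shifted into [0, num_local); -1 marks samples
    whose class is owned by another GPU.
    """
    rng = random.Random(seed)
    positive = sorted({y for y in local_labels if y >= 0})
    if num_sample >= len(positive):
        # Every center gets a random priority in [0, 1); positives get 2.0,
        # so they always survive the top-k selection.
        priority = [rng.random() for _ in range(num_local)]
        for p in positive:
            priority[p] = 2.0
        top = sorted(range(num_local), key=lambda i: priority[i], reverse=True)
        index = sorted(top[:num_sample])
    else:
        # More positives than buffer slots: keep the positives only.
        index = positive
    # New label = position of the old label inside the sorted index.
    remapped = [bisect.bisect_left(index, y) if y >= 0 else -1
                for y in local_labels]
    return index, remapped


if __name__ == "__main__":
    index, remapped = sample_centers([7, -1, 2, 7], num_local=10, num_sample=4, seed=0)
    print("selected centers:", index)
    print("remapped labels:", remapped)
```

After the remapping, a label indexes into the sampled weight rows rather than the full local weight matrix, which is why the original function warns that it changes the labels.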

The selected class centers are then multiplied with the features to compute the margin-based softmax loss.
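As a rough illustration of that step, the sketch below computes the loss for a single sample in plain Python. The notes do not say which margin function is used, so an ArcFace-style additive angular margin is assumed here; the names margin_softmax_loss, scale, and margin and their default values are illustrative. The case where the angle plus the margin exceeds pi is not handled.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def margin_softmax_loss(embedding, centers, target, scale=64.0, margin=0.5):
    """Cross-entropy over cosine logits, with an angular margin on the target class."""
    e = l2_normalize(embedding)
    logits = []
    for k, center in enumerate(centers):
        c = l2_normalize(center)
        # Cosine similarity, clamped to [-1, 1] against rounding error.
        cos = max(-1.0, min(1.0, sum(a * b for a, b in zip(e, c))))
        if k == target:
            # Assumed ArcFace-style margin: enlarge the target angle.
            cos = math.cos(math.acos(cos) + margin)
        logits.append(scale * cos)
    # Log-sum-exp with the maximum subtracted for numerical stability.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]


if __name__ == "__main__":
    sampled_centers = [[1.0, 0.1], [0.0, 1.0]]  # rows of the sampled weight
    print(margin_softmax_loss([1.0, 0.0], sampled_centers, target=0, scale=4.0))
```

With PFC, centers holds only the sampled rows, so the softmax denominator runs over the buffer instead of over every class.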

The workflow of PFC under the DDP framework is shown in the figure below.

The complete code is as follows:

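from typing import Callable

import torch
from torch import distributed
from torch.nn.functional import linear, normalize

# DistCrossEntropy and AllGather come from the original repository and are not reproduced in these notes.


class PartialFC_V2(torch.nn.Module):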
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
    When the sample rate is less than 1, each iteration selects the positive class centers and a
    random subset of negative class centers to compute the margin-based softmax loss. All class
    centers are still maintained throughout the whole training process, but only the selected
    subset is updated in each iteration.
    .. note::
        When the sample rate equals 1, Partial FC is equivalent to model parallelism (the default sample rate is 1).
    Example:
    --------
    >>> # margin_loss: any callable invoked as margin_loss(logits, labels)
    >>> module_pfc = PartialFC_V2(margin_loss, embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels)
    >>>     loss.backward()
    >>>     optimizer.step()
    """
    _version = 2

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
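        Parameters:
        -----------
        margin_loss: Callable
            Callable invoked as margin_loss(logits, labels) before the cross-entropy, required
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            Fraction of the local class centers (positives included) used in each iteration, default is 1.0.
        fp16: bool
            Whether to compute the logits under autocast, default is False.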
        """
        super(PartialFC_V2, self).__init__()
        assert (
            distributed.is_initialized()
        ), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16
        self.num_local: int = num_classes // self.world_size + int(
            self.rank < num_classes % self.world_size
        )
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0

        self.is_updated: bool = True
        self.init_weight_update: bool = True
        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

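        # margin_loss is applied to the logits before the cross-entropy, so it must be callable
        if callable(margin_loss):
            self.margin_softmax = margin_loss
        else:
            raise TypeError("margin_loss must be callable")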

    def sample(self, labels, index_positive):
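        """
            Sample the class centers used in this iteration.
            This function modifies `labels` in place.
            Parameters:
            -----------
            labels: torch.Tensor
                Local labels, already shifted into [0, num_local); -1 marks classes owned by another rank.
            index_positive: torch.Tensor
                Boolean mask marking the entries of `labels` whose class lives on this rank.
        """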
        with torch.no_grad():
            positive = torch.unique(labels[index_positive], sorted=True).cuda()
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local]).cuda()
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1].cuda()
                index = index.sort()[0].cuda()
            else:
                index = positive
            self.weight_index = index

            labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        return self.weight[self.weight_index]

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).
        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)
        
        # Select the samples whose labels fall in the class range owned by this process
        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
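        # Labels outside this process's class range are set to -1
        labels[~index_positive] = -1
        # Shift class IDs so they start at 0: each process initializes only its own self.weight,
        # so without the shift the labels would not match the row indices of self.weight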
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            weight = self.sample(labels, index_positive)
        else:
            weight = self.weight

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(weight)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss
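The class partitioning in __init__ and the label shift in forward can be checked on a toy configuration. The sketch below repeats the same arithmetic in plain Python; the function names partition and localize are illustrative and do not exist in the repository.

```python
def partition(num_classes, world_size, rank):
    """Return (num_local, class_start) for one rank, as computed in __init__."""
    base, remainder = divmod(num_classes, world_size)
    # The first `remainder` ranks each take one extra class.
    num_local = base + int(rank < remainder)
    class_start = base * rank + min(rank, remainder)
    return num_local, class_start


def localize(labels, num_local, class_start):
    """Shift global labels into [0, num_local); labels owned by other ranks become -1."""
    return [y - class_start if class_start <= y < class_start + num_local else -1
            for y in labels]


if __name__ == "__main__":
    num_classes, world_size = 10, 4
    gathered_labels = [0, 3, 5, 9]
    for rank in range(world_size):
        num_local, class_start = partition(num_classes, world_size, rank)
        print(rank, num_local, class_start,
              localize(gathered_labels, num_local, class_start))
```

Each label is non-negative on exactly one rank, which is what lets every GPU compute logits only for the class centers it owns.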

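Experimental results

Replacing the conventional FC with PFC improves model performance on WebFace (4M, 12M, and 42M).

The results of the ablation experiments are shown below.

The comparison with SOTA methods is shown below.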

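Conclusion and discussion

Conclusion

The authors propose Partial FC (PFC), a method for training face recognition models on large-scale datasets. In each iteration of PFC, only a small subset of class centers is selected to compute the margin-based softmax loss, which significantly reduces the probability of inter-class conflict, the frequency of passive updates to tail class centers, and the computational requirements. Through extensive experiments, the authors verify the effectiveness, robustness, and efficiency of the proposed PFC.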
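The reduction in passive updates can be estimated from the sampling scheme in the code above. Positives always enter the buffer and the remaining slots are filled uniformly at random, so a center with no sample in the batch is picked with probability (num_sample - P) / (num_local - P), where P is the number of positive centers. The sketch below works this out under that assumption; the function name and the example numbers are illustrative.

```python
def passive_update_probability(num_local, sample_rate, num_positive):
    """Chance that a center with no sample in the batch is still picked as a negative."""
    num_sample = int(sample_rate * num_local)
    if num_sample <= num_positive:
        return 0.0  # the buffer holds positive centers only
    return (num_sample - num_positive) / (num_local - num_positive)


if __name__ == "__main__":
    # Hypothetical setup: 100,000 local classes, 512 positive centers per iteration.
    for rate in (1.0, 0.5, 0.25):
        print(rate, passive_update_probability(100_000, rate, 512))
```

With a sample rate of 1.0 every negative center takes part in every iteration, while lower rates cut the passive update frequency roughly in proportion.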

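Limitations

Although PFC models trained on WebFace achieve good results on high-quality test sets, they may perform poorly on low-resolution faces or on faces captured in low-light conditions.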
