人脸识别辅助损失函数:centerloss

人脸识别有三个步骤:

1. 人脸检测:face detection,网络将图片上的人脸框出来,网络可采用MTCNN、YOLO等。
2. 特征提取:feature extraction,将获取的人脸框传入特征提取网络提取人脸特征,网络可采用ResNet, MobilNet等。
3. 对比识别:face recognition,创建人脸特征库,将每个人脸标签和人脸特征作为一组特征值保存到人脸特征库。需要识别的人脸经过特征提取网络提取到特征后,与人脸特征库里面的特征比对,相似度大于一定阈值的就是同一个人。

特征提取时,一般要用softmax进行分类。但用原始的softmax提取到的特征类内距较大,类间距较小,因此在与人脸库中的不同人脸特征对比的时候,相似度差距较小,准确率难如人意。因此,有必要对softmax进行改造,改造的方法很多,加上centerloss就是其中之一。
论文链接:http://ydwen.github.io/papers/WenECCV16.pdf

softmaxloss公式如下:

在这里插入图片描述
其中m代表每批次取样本的数量,n代表一共有多少人(多少类),权重w相当于n个向量组成的矩阵。在训练的时候,提取到的人脸特征x根据标签确定属于哪一类,然后和w中的对应向量做内积。如果特征向量x与权重向量 w i w_i wi内积很大,代表两个向量相似度很高,通过softmax输出此人脸特征属于这个类的概率也会很大。

centerloss公式如下:

在这里插入图片描述
centerloss给每个类设定了一个中心点,在训练的时候,将每个提取到的特征向量与对应类中心点的L2范数的平方作为损失,损失越大,说明特征向量距离对应类的中心点越远。降低此损失,每个类的人脸特征会距离中心点更近,也就是类内距变小。

中心点更新公式如下:

在这里插入图片描述
在训练的时候,每个类的中心点也要更新,使中心点和提取到的特征相互适应。更新的学习率一般设定为0.5

总的公式如下:

在这里插入图片描述
Ls为softmaxloss,主要作用为分类,Lc为centerloss,主要作用为减少类内距。两者共同作用,使每个类的人脸特征提取的更好。
其中的 λ \lambda λ平衡centerloss与softmaxloss的比重。

centerloss的缺点:

  1. 类别较多时,对硬件要求较高:类别越多中心点越多,计算量越大,对内存,GPU等要求越高
  2. L2范数的离群点对损失的影响较大,离群点难以回归到中心点
  3. 类内距难以训练到足够小
  4. 只适合同类样本差异不大的数据,若样本本身同一类差距就很大则难以区分

代码:

import torch
import torch.nn as nn
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, 3, 1),  # 26
            nn.BatchNorm2d(16),
            nn.PReLU(),
            nn.Conv2d(16, 32, 3, 1),  # 24
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),  # 12
            nn.Conv2d(32, 64, 3, 1),  # 10
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(64, 64, 3, 1),  # 8
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(64, 128, 3, 1),  # 6
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.Conv2d(128, 128, 3, 1),  # 4
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.MaxPool2d(2, 2)  # 2
        )
        self.layer2 = nn.Linear(2*2*128, 2)
        self.layer3 = nn.Linear(2, 10)
        self.center = Centerloss()

    def forward(self, x):
        x = self.layer1(x)
        x = x.view(x.size(0), -1)
        y_feature = self.layer2(x)  # N, 2
        y_output = self.layer3(y_feature)

        return y_feature, y_output


class Centerloss(nn.Module):
    def __init__(self, cls_num=10, feature_dim=2):
        super().__init__()
        self.center = nn.Parameter(torch.randn(cls_num, feature_dim), requires_grad=True)

    def forward(self, x, label, lambdas):
        center_exp = self.center[label]
        count = torch.histc(label, bins=int(max(label) + 1), max=int(max(label)), min=min(label))
        count_exp = count[label]
        a = torch.pow((x - center_exp), 2)  # N 10
        b = torch.sum(a, dim=1)  # N
        c = torch.div(b, count_exp.float())  # N
        d = torch.div(torch.sum(c), 10)  # N
        loss = 1/2 * lambdas * d
        return loss
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as trans
from center.net import Net
import os
import matplotlib.pyplot as plt


class Trainer:
    def __init__(self, save_path):
        self.save_path = save_path
        self.transform = trans.Compose([
            trans.ToTensor(), trans.Normalize([0.5,], [0.5])])
        self.train_data = torchvision.datasets.MNIST(root="MINIST", download=True, transform=self.transform)
        self.train_loader = data.DataLoader(self.train_data, batch_size=500, shuffle=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = Net().to(self.device)
        self.loss_softmax = nn.CrossEntropyLoss()
        self.optim_center = torch.optim.SGD(self.net.center.parameters(), lr=0.5, momentum=0.9)
        self.optim_softmax = torch.optim.Adam(self.net.parameters())

    def train(self):
        if os.path.exists(self.save_path):
            # self.net.load_state_dict(torch.load(self.save_path, map_location="cpu"))
            self.net.load_state_dict(torch.load(self.save_path))
        else:
            print("NO Param ")

        epoch = 0
        while True:
            feat_loader = []
            label_loader = []
            for i, (x, y) in enumerate(self.train_loader):
                x = x.to(self.device)
                y = y.to(self.device)

                feature, output = self.net(x)
                loss_center = self.net.center(feature, y, 2)
                loss_softmax = self.loss_softmax(output, y.long())
                loss = loss_softmax + loss_center

                self.optim_softmax.zero_grad()
                loss_center.backward(retain_graph=True)
                loss_softmax.backward()
                self.optim_softmax.step()
                self.optim_center.step()

                feat_loader.append(feature)
                label_loader.append(y)

                if i % 10 == 0:
                    print(epoch, loss.item(), loss_center.item(), loss_softmax.item())
            feat = torch.cat(feat_loader, 0)
            labels = torch.cat(label_loader, 0)

            self.visualize(feat.data.cpu().numpy(), labels.data.cpu().numpy(), epoch)
            epoch += 1
            torch.save(self.net.state_dict(), self.save_path)

            if epoch == 150:
                break

    def visualize(self, feat, labels, epoch):
        plt.ion()
        color = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
                 '#ff00ff', '#990000', '#999900', '#009900', '#009999']
        plt.clf()
        for i in range(10):
            plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=color[i])
            plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
            # plt.xlim(xmin=-5,xmax=5)
            # plt.ylim(ymin=-5,ymax=5)
            plt.title("epoch=%d" % epoch)
            plt.savefig('./images/epoch=%d.jpg' % epoch)
            # plt.draw()
            # plt.pause(0.001)


if __name__ == '__main__':

    train = Trainer("models/net.pth")
    train.train()

可视化结果:

softmax:
在这里插入图片描述
softmax + centerloss:

在这里插入图片描述

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值