Face Recognition: Center Loss and Its Implementation

Center Loss

Why use Center Loss?

To start with, Center Loss is mainly used for face recognition. Why not simply classify faces with softmax alone? Because facial features are extremely similar to one another, samples near the boundary between classes are hard to separate; in other words, softmax may assign two faces probabilities of roughly 0.5 each, so the classification result is unreliable. How can these boundary regions be pulled apart? There are two approaches: (1) enlarge the inter-class distance, or (2) shrink the intra-class distance. Center Loss takes the second approach.
Below are the plots I made on the 10-class MNIST problem, comparing softmax loss alone with softmax loss + center loss:
[Figure: 2-D feature distribution on MNIST trained with softmax loss only]
The better version of this figure was accidentally deleted, but this one still roughly shows the weakness of softmax loss: the closer the features are to the center, the harder they are to tell apart.
[Figure: 2-D feature distribution on MNIST trained with softmax loss + center loss]
The figure above shows the result with softmax loss + Center Loss: all classes are clearly separable.

How to implement it

Looking at the formulas directly: one term is the softmax loss and the other is the center loss; adding them together gives the total loss.
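As a reference, the joint objective from the Center Loss paper (Wen et al., ECCV 2016) can be written as follows, with m the mini-batch size, n the number of classes, x_i the deep feature of sample i, and c_{y_i} the center of its class:

$$
L = L_S + \lambda L_C
  = -\sum_{i=1}^{m}\log\frac{e^{W_{y_i}^{T}x_i + b_{y_i}}}{\sum_{j=1}^{n}e^{W_{j}^{T}x_i + b_j}}
  + \frac{\lambda}{2}\sum_{i=1}^{m}\left\lVert x_i - c_{y_i}\right\rVert_2^2
$$

λ balances the two terms; the training loop below uses λ = 0.1. Note that the Centerloss module in the code computes the per-sample Euclidean distance (not the squared form) divided by the number of samples of that class in the batch, but the idea is the same: pull each feature toward the center of its own class.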

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import torchvision
import torch.utils.data as data
import matplotlib.pyplot as plt
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.transforms as transforms


class Centerloss(torch.nn.Module):
    def __init__(self, feature_num, cls_num):
        super(Centerloss, self).__init__()
        self.cls_num = cls_num
        self.center = nn.Parameter(torch.randn(cls_num, feature_num))

    def forward(self, xs, ys):
        #  xs = F.normalize(xs)
        # Use each label as a row index into the center matrix, picking out
        # the center of every sample's own class.
        center_exp = self.center.index_select(dim=0, index=ys.long())
        # Count how many samples of each class (0 .. cls_num-1) are in the batch;
        # histc expects a floating-point input, so cast the labels first.
        count = torch.histc(ys.float(), bins=self.cls_num, min=0, max=self.cls_num - 1)
        # Use the labels again to look up, for every sample, the size of its class.
        count_dis = count.index_select(dim=0, index=ys.long())

        # Euclidean distance of each feature to its class center, normalized by
        # the class count in the batch, summed over all samples.
        return torch.sum(torch.sqrt(torch.sum((xs - center_exp) ** 2, dim=1)) / count_dis.float())
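
# A small walk-through (illustrative values only, not from the original post) of what
# the index_select / histc bookkeeping in Centerloss.forward does for a toy batch
# of 4 samples, 3 classes and 2-D features:
toy_centers = torch.tensor([[0., 0.], [10., 10.], [20., 20.]])   # one 2-D center per class
toy_ys = torch.tensor([0, 1, 1, 2])                              # batch labels
toy_xs = torch.tensor([[1., 1.], [9., 9.], [11., 11.], [19., 21.]])
toy_center_exp = toy_centers.index_select(dim=0, index=toy_ys)   # each sample's own class center
toy_count = torch.histc(toy_ys.float(), bins=3, min=0, max=2)    # samples per class: [1., 2., 1.]
toy_count_dis = toy_count.index_select(dim=0, index=toy_ys)      # per-sample class size: [1., 2., 2., 1.]
toy_dist = torch.sqrt(torch.sum((toy_xs - toy_center_exp) ** 2, dim=1))  # distance to own center
toy_loss = torch.sum(toy_dist / toy_count_dis)                   # value Centerloss.forward would return here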



class ClsNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layer = nn.Sequential(nn.Conv2d(1, 32, 3), nn.BatchNorm2d(32), nn.PReLU(),
                                        nn.Conv2d(32, 64, 3), nn.BatchNorm2d(64), nn.PReLU(),
                                        nn.MaxPool2d(3, 2))
        # The 2-D feature output makes it easy to scatter-plot the learned features.
        self.feature_layer = nn.Sequential(nn.Linear(11*11*64, 256), nn.BatchNorm1d(256), nn.PReLU(),
                                           nn.Linear(256, 128), nn.BatchNorm1d(128), nn.PReLU(),
                                           nn.Linear(128, 2), nn.PReLU())
        self.out_layer = nn.Sequential(nn.Linear(2, 10))
        self.loss_fn1 = Centerloss(2, 10)
        self.loss_fn2 = nn.CrossEntropyLoss()


    def forward(self, x):
        conv = self.conv_layer(x)
        conv = conv.reshape(x.size(0), -1)
        self.feature = self.feature_layer(conv)
        self.out = self.out_layer(self.feature)
        return self.feature

    def get_loss(self, ys):

        loss1 = self.loss_fn1(self.feature, ys)
        loss2 = self.loss_fn2(self.out, ys.long())
        return loss1, loss2
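
# Quick sanity check (hypothetical names, CPU only, not part of the original post)
# that the network and both losses wire together; shapes follow the definitions above.
demo_net = ClsNet()
demo_x = torch.randn(8, 1, 28, 28)             # a fake batch of 8 MNIST-sized images
demo_y = torch.randint(0, 10, (8,))            # fake labels
demo_features = demo_net(demo_x)               # 2-D features, shape (8, 2)
demo_loss1, demo_loss2 = demo_net.get_loss(demo_y)
print('feature shape:', demo_features.shape,
      'center loss:', demo_loss1.item(),
      'softmax loss:', demo_loss2.item())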


if __name__ == '__main__':

    train_data = torchvision.datasets.MNIST(
        root='mnist',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )
    test_data = torchvision.datasets.MNIST(
        root='mnist',
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=False
    )

    train = data.DataLoader(dataset=train_data, batch_size=1024, shuffle=True, drop_last=True)
    test = data.DataLoader(dataset=test_data, batch_size=1024, shuffle=True)
    net = ClsNet().cuda()
    # net = net.to(device)
    path = r'params/weightnet2.pt'
    if os.path.exists(path):
        net.load_state_dict(torch.load(path))
        # keep the network in train mode, since training continues below
        print('load successful')
    else:
        print('load failed, training from scratch')


    # optimism = optim.SGD(net.parameters(), lr=1e-3)
    optimism = optim.Adam(net.parameters(), lr=0.0005)
    # scheduler = lr_scheduler.StepLR(optimism, 10, gamma=0.8)
    # optimizer = optim.SGD(net.parameters(), weight_decay=0.0005, lr=0.001, momentum=0.9)
    # scheduler = lr_scheduler.StepLR(optimizer, 20, gamma=0.8)
    # optimizercenter = optim.SGD(Centerloss.parameters(), lr=0.5)
    losses = []
    # one scatter-plot colour per digit class
    c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
         '#ff00ff', '#990000', '#999900', '#009900', '#009999']

    epoch = 10000
    d = 0  # counter used to number the saved scatter plots
    # fig, ax = plt.subplots()

    for i in range(epoch):
        # scheduler.step()
        print('epoch: {}'.format(i))
        print('batches per epoch:', len(train))
        tar = []
        out = []
        for j, (input, target) in enumerate(train):
            input = input.cuda()
            target = target.cuda()
            output = net(input)

            loss1, loss2 = net.get_loss(target)
            loss = 0.1 * loss1 + loss2

            # label = torch.argmax(output, dim=1)  # take the index of the max logit as the predicted label

            # zero the gradients, backpropagate, update the parameters
            optimism.zero_grad()
            loss.backward()
            optimism.step()

            output = output.cpu().detach().numpy()
            # print(output)
            target = target.cpu().detach()
            out.extend(output)
            tar.extend(target)

            print('[epoch {} - batch {}/{}] loss: {:.4f}  center loss: {:.4f}  softmax loss: {:.4f}'.format(
                i, j, len(train), loss.item(), loss1.item(), loss2.item()))
            outstack = np.stack(out)
            tarstack = torch.stack(tar)

            # plt.cla()
            plt.ion()
            if j == 3:
                d += 1
                for m in range(10):
                    index = torch.nonzero(tarstack == m)
                    plt.scatter(outstack[:, 0][index[:, 0]], outstack[:, 1][index[:, 0]], c=c[m], marker='.')
                plt.show()
                plt.pause(10)

                plt.savefig('picture1.1/{0}.jpg'.format(d))
                print('save success')
            # plt.ioff()
            # plt.clf()
            plt.close()

        torch.save(net.state_dict(), r'params/weightnet2.pt')
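
The test DataLoader created above is never actually used. A minimal evaluation sketch, reusing the net and test objects from the script (an assumption about how one might extend it, not part of the original code), could look like this:

net.eval()
correct = 0
total = 0
with torch.no_grad():
    for input, target in test:
        input, target = input.cuda(), target.cuda()
        net(input)                             # forward pass fills net.out
        pred = torch.argmax(net.out, dim=1)    # predicted class per sample
        correct += (pred == target).sum().item()
        total += target.size(0)
print('test accuracy: {:.4f}'.format(correct / total))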