CenterLoss在Mnist数据集上的实现

centerloss,顾名思义,中心损失函数,它的原理主要是在softmax loss的基础上,通过对训练集的每个类别在特征空间分别维护一个类中心,在训练过程,增加样本经过网络映射后在特征空间与类中心的距离约束,从而兼顾了类内聚合与类间分离,centerloss只是一个辅助损失函数,softmaxloss才是主打,但softmaxloss只能简单的将类分开,还得加上centerloss这一个强力辅助才能保证特征之间不仅具有可分性,同时也具有可判别性。

我们都知道,对于分类来说,希望类内距小,类间距大,那centerloss+softmaxloss就有这种功能。

简单复习一下softmax函数:

关于softmax的这个函数,有一些基本特性:是归一化指数函数,本质是离散概率分布,常用于多分类,值域为[0,1],输出结果之和为1。

那接着就来看看softmaxloss这个损失函数:

其中Sj为sigmoid输出的值,yj为标签对应独热编码的值(0或者1)

因此softmaxloss可以化简为:

log函数大家都知道,是一个定义域为[0,+∞],值域在[-∞,∞]的增函数,那么softmaxloss定义域在[0,1],取log就是在[-∞,1],那么取-log整个函数最终就变成了定义域在[0,1],值域在[0,+∞]的减函数,并且过(1,0)这个点,这一点正好符合我们梯度下降(当概率为1,损失下降到0),因此我们就可以使用softmaxloss来一步步降低分类的损失。

关于cneterloss,可以先看看公式:

N表示mini-batch的大小,xi表示输出特征,C表示对应的i个类中心,因此centerloss就是希望一个batch中的每个样本的feature离feature 的中心的距离的平方和要越小越好,也就是类内距离要越小越好。
反向传播:

α是学习率,也就是步长,设置一般取值0.5。

这里有一个问题就是centerloss学习率取值为0.5,那如果用同一个优化器进行优化,必然会造成softmaxloss的梯度爆炸,导致整个模型崩溃。

因此这里我们想到用两个优化器进行优化,分别优化centerloss和softmaxloss,这一点可以在代码里看到。

两个损失函数共同作用,softmaxloss负责大致分开各数据,centerloss使类内距越来越小,各司其职,达到把特征区分到最佳效果。其中λ是一个超参数,表示训练时更加倾向于哪个的损失,我在训练时候,λ选择2。

下面看看训练的效果吧:

我只训练了39轮,其实还是能看出来效果还是挺好的。

大概提一下训练过程中的坑吧,因为这些中心点是随机的,有可能随机到的中心点不好,数据点久久不能分开,建议中止训练重新开始或者直接删除参数重新训练。还有就是λ的值对结果影响挺大的,小心调参。

代码:

import torch
import torch.nn as nn

class CLNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.center = nn.Parameter(torch.randn(100, 2), requires_grad=True)  # (10, 2)

    def forward(self, feature, label, lambdas=2):
        center_exp = self.center.index_select(dim=0, index=label.long())  # (100, 2)
        count = torch.histc(label, bins=int(max(label).item() + 1), min=int(min(label).item()), max=int(max(label).item()))  # (10,)
        count_exp = count.index_select(dim=0, index=label.long())  # (100,)
        loss = lambdas / 2 * torch.mean(torch.div(torch.sum(torch.pow(feature - center_exp, 2), dim=1), count_exp))
        return loss

import torch
from Net_Model import Net
from centerloss import CLNet
import torch.nn as nn
from torchvision import transforms, datasets
import os

class Trainer:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.s_net = Net().to(self.device)
        self.c_net = CLNet().to(self.device)
        self.s_save_path = "models/softmax_net.pth"
        self.c_save_path = "models/center_net.pth"
        self.nll_loss = nn.NLLLoss()
        self.s_optimizer = torch.optim.SGD(self.s_net.parameters(), lr=0.0005, momentum=0.9, weight_decay=0.0005)
        self.c_optimizer = torch.optim.SGD(self.c_net.parameters(), lr=0.5)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.s_optimizer, gamma=0.95, last_epoch=-1)
        self.mean, self.std = self.mean_std()
        self.dataLoader = self.data_loader()

    def mean_std(self):
        sets = datasets.MNIST("./MNIST", train=True, download=False, transform=transforms.ToTensor())
        loader = torch.utils.data.DataLoader(sets, batch_size=len(sets), shuffle=True)
        data = next(iter(loader))[0]
        mean = round(torch.mean(data, dim=(0, 2, 3)).item(), 3)
        std = round(torch.std(data, dim=(0, 2, 3)).item(), 3)
        return mean, std

    def data_loader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((self.mean,), (self.std,))
        ])
        dataSet = datasets.MNIST("./MNIST", train=True, download=False, transform=transform)
        dataLoader = torch.utils.data.DataLoader(dataSet, batch_size=100, shuffle=True, num_workers=4)
        return dataLoader

    def train_test(self):
        if os.path.exists(self.s_save_path) and os.path.exists(self.c_save_path):
            self.s_net.load_state_dict(torch.load(self.s_save_path))
            self.c_net.load_state_dict(torch.load(self.c_save_path))
        else:
            print("NO Param")
        epoch = 0
        while True:
            feature_loader = []
            label_loader = []
            for i, (x, y) in enumerate(self.dataLoader):
                x = x.to(self.device)
                y = y.to(self.device)
                feature, output = self.s_net(x)
                nll_loss = self.nll_loss(output, y)

                y = y.float()
                center_loss = self.c_net(feature, y, 2)
                loss = nll_loss + center_loss

                self.s_optimizer.zero_grad()
                self.c_optimizer.zero_grad()
                loss.backward()
                self.s_optimizer.step()
                self.c_optimizer.step()

                feature_loader.append(feature)
                label_loader.append(y)
                if i % 100 == 0:
                    print("epoch:", epoch, "i:", i, "loss:", loss.item(), "softmax_loss:", nll_loss.item(),
                          "center_loss:", center_loss.item())
            features = torch.cat(feature_loader, dim=0)
            labels = torch.cat(label_loader, dim=0)
            self.s_net.visualize(features.data.cpu().numpy(), labels.data.cpu().numpy(), epoch)
            torch.save(self.s_net.state_dict(), self.s_save_path)
            torch.save(self.c_net.state_dict(), self.c_save_path)
            self.scheduler.step(None)
            epoch += 1

            if epoch == 40:
                break
if __name__ == '__main__':
    Trainer=Trainer()
    Trainer.train_test()
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值