PyTorch 实现 Center Loss(实战)

下面的完整代码在 GitHub 仓库中:传送门(原文中的链接在转载时已丢失)


一、定义 Center Loss 函数

import torch
import torch.nn as nn

def center_loss(feature, label, lambdas, center=None):
    """Compute a class-count-normalised center loss for one batch.

    For every sample i with class y_i, the squared Euclidean distance between
    ``feature[i]`` and ``center[y_i]`` is divided by the number of samples of
    class y_i in the batch. The mean of these values is scaled by
    ``lambdas / 2``.

    Args:
        feature: (N, D) feature tensor.
        label: (N,) tensor of non-negative whole-number class ids; float or
            integer dtype is accepted (it is converted with ``.long()``).
        lambdas: weight of the loss.
        center: optional (C, D) tensor of class centers, e.g. an
            ``nn.Parameter`` owned by the caller so that the centers are
            learned and persist across batches. When omitted, random centers
            are drawn on every call (the original behaviour). Such centers
            are not learned and change from batch to batch.

    Returns:
        Scalar loss tensor.
    """
    if label.numel() == 0:
        # No samples: zero loss that stays connected to ``feature``'s graph.
        return feature.sum() * 0

    index = label.long()
    if center is None:
        num_classes = int(index.max().item()) + 1
        # Allocate on the features' device/dtype instead of hard-coding
        # .cuda(), so the loss also works on CPU-only machines.
        center = torch.randn(num_classes, feature.shape[1],
                             device=feature.device, dtype=feature.dtype)

    # Center of each sample's own class: (N, D).
    center_exp = center.index_select(dim=0, index=index)

    # Samples per class in this batch, then gathered per sample: (N,).
    # bincount works for integer labels on every device (histc does not).
    count = torch.bincount(index, minlength=center.shape[0]).to(feature.dtype)
    count_exp = count.index_select(dim=0, index=index)

    dist = torch.sum(torch.pow(feature - center_exp, 2), dim=1)
    return lambdas / 2 * torch.mean(torch.div(dist, count_exp))

if __name__ == '__main__':
    # Small demo: 5 two-dimensional samples from 2 classes.
    # Fall back to CPU when no GPU is present instead of failing on .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.tensor([[3, 4], [5, 6], [7, 8], [9, 8], [6, 5]],
                        dtype=torch.float32, device=device)
    label = torch.tensor([0, 0, 1, 0, 1], dtype=torch.float32, device=device)
    loss = center_loss(data, label, 2)
    print(loss)

二、搭建网络模型

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 可以使用mobilenet-v2
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1, 2),  # 28*28
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.Conv2d(32, 32, 5, 1, 2),  # 28*28
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),  # 14*14

            nn.Conv2d(32, 64, 5, 1, 2),  # 14*14
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(64, 64, 5, 1, 2),  # 14*14
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),  # 7*7

            nn.Conv2d(64, 128, 5, 1, 2),  # 7*7
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.Conv2d(128, 128, 5, 1, 2),  # 7*7
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.MaxPool2d(2, 2)  # 3*3

        )
        self.feature = nn.Linear(128*3*3, 2)
        self.output = nn.Linear(2, 10)

    def forward(self, x):
        y_conv = self.conv_layer(x)
        y_conv = torch.reshape(y_conv, [-1, 128*3*3])
        y_feature = self.feature(y_conv)  # N,2
        y_output = torch.log_softmax(self.output(y_feature), dim=1)  # N,10

        return y_feature, y_output

    def visualize(self, feat, labels, epoch):
        # plt.ion()
        color = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
             '#ff00ff', '#990000', '#999900', '#009900', '#009999']
        # plt.clf()
        for i in range(10):
            plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=color[i])
            # 将60000个特征点分到10个类里面,并画在坐标轴上

        plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'], loc='upper right')
        # plt.xlim(xmin=-5, xmax=5)
        # plt.ylim(ymin=-5, ymax=5)
        plt.title("epoch=%d" % epoch)
        plt.savefig('./images/epoch=%d.jpg' % epoch)
        # plt.draw()
        # plt.pause(0.001)

'''
softmax + CELoss = softmax loss
log_softmax + NLLLoss = softmax loss
'''

三、训练模型

import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torch.optim.lr_scheduler as lr_scheduler
from Net_Model import Net
from center_loss import center_loss
import os
import numpy as np

if __name__ == '__main__':
    save_path = "models/net_center.pth"
    # Named `transform` so it does not shadow the torchvision `transforms` module.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ]
    )
    train_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=True,
                                            transform=transform)
    test_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=False,
                                           transform=transform)
    train_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=512,
                                   num_workers=2)
    test_loader = data.DataLoader(dataset=test_data, shuffle=True, batch_size=256,
                                  num_workers=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = Net().to(device)

    if os.path.exists(save_path):
        # map_location allows a checkpoint saved on GPU to load on a CPU-only machine.
        net.load_state_dict(torch.load(save_path, map_location=device))
    else:
        print("No Param")
    # torch.save does not create missing directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # nn.CrossEntropyLoss() == log_softmax + nn.NLLLoss(). The network already
    # applies log_softmax in forward(), so NLLLoss is the matching criterion.
    loss_fn = nn.NLLLoss()
    # Plain SGD without momentum. The author suggests momentum 0.9 for the first
    # 10 epochs, 0.3 for the next 10 and 0 afterwards; that schedule is not
    # implemented here.
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0)

    for epoch in range(100000):
        net.train()
        feat_loader = []
        label_loader = []
        for i, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            # Call the module (not .forward) so nn.Module hooks run.
            feature, output = net(x)

            # output is already log_softmax, so NLLLoss gives the softmax loss.
            loss_cls = loss_fn(output, y)
            # Center-loss weight 0.5; the author suggests keeping it small (< 2).
            loss_center = center_loss(feature, y.float(), 0.5)

            loss = loss_cls + loss_center  # softmax loss + center loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Detach before storing: keeping the raw tensors would retain every
            # batch's autograd graph (and activations) for the whole epoch.
            feat_loader.append(feature.detach())
            label_loader.append(y.detach())

            if i % 10 == 0:
                print("epoch:", epoch, "i:", i, "total_loss:", loss.item(),
                      "Softmax_loss", loss_cls.item(), "center_loss", loss_center.item())

        feat = torch.cat(feat_loader, 0)
        labels = torch.cat(label_loader, 0)
        net.visualize(feat.cpu().numpy(), labels.cpu().numpy(), epoch)
        torch.save(net.state_dict(), save_path)

        # Evaluation: BatchNorm must use its running statistics (eval mode), and
        # no gradients are needed.
        net.eval()
        eval_loss_cls = 0
        eval_acc_cls = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                feature, output = net(x)

                loss_cls = loss_fn(output, y)
                eval_loss_cls += loss_cls.item() * y.size(0)
                out_argmax = torch.argmax(output, 1)
                eval_acc_cls += (out_argmax == y).sum().item()

        mean_loss_cls = eval_loss_cls / len(test_data)
        mean_acc_cls = eval_acc_cls / len(test_data)
        print("分类平均损失:{} 分类平均精度{}".format(mean_loss_cls, mean_acc_cls))

        # Classification quality is judged by accuracy. Ideas for improvement:
        # 1. After training, improve the model or try other optimizers (e.g. turn
        #    the center loss into a class whose centers are trainable).
        # 2. The SGD learning rate could be raised to 0.5.
  • 4
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值