下面完整代码在 GitHub 仓库(原文中"传送门"处的超链接在转载时丢失,未附具体地址)。
一、定义Center loss函数
import torch
import torch.nn as nn
def center_loss(feature, label, lambdas):
    """Compute a center loss over a batch of embedding vectors.

    NOTE(review): the class centers are re-drawn from a standard normal on
    every call, so they are never actually learned; the TODO at the end of
    the training script already suggests rewriting this as a class with a
    trainable center parameter. The original also wrapped the centers in
    nn.Parameter(...).cuda(), which yields a non-leaf copy whose gradient is
    discarded — dropping the Parameter wrapper changes nothing observable.

    Args:
        feature: (N, D) float tensor of embeddings.
        label:   (N,) float tensor of class indices (float because
                 torch.histc requires a floating-point input).
        lambdas: scalar weight of the loss term.

    Returns:
        Scalar tensor: lambdas/2 * mean over samples of the squared distance
        to the sample's class center, each divided by that class's count.
    """
    num_classes = int(label.max().item()) + 1
    # Centers live on the same device as the features (was hard-coded .cuda(),
    # which crashed on CPU-only machines).
    center = torch.randn(num_classes, feature.shape[1], device=feature.device)
    # Per-sample center lookup: (N, D)
    center_exp = center.index_select(dim=0, index=label.long())
    # Per-class sample counts, then expanded to one count per sample: (N,)
    count = torch.histc(label, bins=num_classes, min=0, max=num_classes - 1)
    count_exp = count.index_select(dim=0, index=label.long())
    loss = lambdas / 2 * torch.mean(
        torch.div(torch.sum(torch.pow(feature - center_exp, 2), dim=1), count_exp)
    )
    return loss
if __name__ == '__main__':
    # Smoke test: 5 samples in 2-D, two classes (3 samples of class 0,
    # 2 samples of class 1).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.tensor([[3, 4], [5, 6], [7, 8], [9, 8], [6, 5]],
                        dtype=torch.float32).to(device)
    # Labels are float32 because torch.histc inside center_loss needs a
    # floating-point input; was hard-coded .cuda(), which crashed without a GPU.
    label = torch.tensor([0, 0, 1, 0, 1], dtype=torch.float32).to(device)
    loss = center_loss(data, label, 2)
    print(loss)
二、搭建网络模型
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# 可以使用mobilenet-v2
class Net(nn.Module):
    """MNIST classifier exposing a 2-D embedding for center-loss training.

    forward() returns (feature, output):
        feature: (N, 2) embedding used by the center loss and for plotting.
        output:  (N, 10) log-probabilities (log_softmax), meant to be paired
                 with nn.NLLLoss (log_softmax + NLLLoss == softmax loss,
                 i.e. CrossEntropyLoss).
    """

    def __init__(self):
        super(Net, self).__init__()
        # Three conv stages; spatial size 28 -> 14 -> 7 -> 3 via MaxPool2d.
        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1, 2),     # 28*28
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.Conv2d(32, 32, 5, 1, 2),    # 28*28
            nn.BatchNorm2d(32),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),            # 14*14
            nn.Conv2d(32, 64, 5, 1, 2),    # 14*14
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(64, 64, 5, 1, 2),    # 14*14
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.MaxPool2d(2, 2),            # 7*7
            nn.Conv2d(64, 128, 5, 1, 2),   # 7*7
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.Conv2d(128, 128, 5, 1, 2),  # 7*7
            nn.BatchNorm2d(128),
            nn.PReLU(),
            nn.MaxPool2d(2, 2)             # 3*3
        )
        # Bottleneck to 2-D so the features can be plotted directly.
        self.feature = nn.Linear(128 * 3 * 3, 2)
        self.output = nn.Linear(2, 10)

    def forward(self, x):
        """Map (N, 1, 28, 28) images to ((N, 2) features, (N, 10) log-probs)."""
        y_conv = self.conv_layer(x)
        y_conv = torch.reshape(y_conv, [-1, 128 * 3 * 3])
        y_feature = self.feature(y_conv)               # N, 2
        y_output = torch.log_softmax(self.output(y_feature), dim=1)  # N, 10
        return y_feature, y_output

    def visualize(self, feat, labels, epoch):
        """Scatter-plot the 2-D features colored by digit class and save the
        figure to ./images/epoch=<epoch>.jpg (directory created if missing).

        Args:
            feat:   (N, 2) numpy array of features.
            labels: (N,) numpy array of integer class labels.
            epoch:  epoch number used in the title and file name.
        """
        import os
        color = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
                 '#ff00ff', '#990000', '#999900', '#009900', '#009999']
        # One scatter series per digit class so the legend maps colors to digits.
        for i in range(10):
            plt.plot(feat[labels == i, 0], feat[labels == i, 1], '.', c=color[i])
        plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'],
                   loc='upper right')
        plt.title("epoch=%d" % epoch)
        # The original crashed when ./images did not exist; create it first.
        os.makedirs('./images', exist_ok=True)
        plt.savefig('./images/epoch=%d.jpg' % epoch)
'''
softmax + CELoss = softmax loss
log_softmax + NLLLoss = softmax loss
'''
三、开始训练数据
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import torch.optim.lr_scheduler as lr_scheduler
from Net_Model import Net
from center_loss import center_loss
import os
import numpy as np
if __name__ == '__main__':
    save_path = "models/net_center.pth"
    # NOTE(review): the original bound the composed pipeline to the name
    # `transforms`, shadowing the torchvision.transforms module; renamed.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ]
    )
    train_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=True,
                                            transform=transform)
    test_data = torchvision.datasets.MNIST(root="./MNIST", download=True, train=False,
                                           transform=transform)
    train_loader = data.DataLoader(dataset=train_data, shuffle=True, batch_size=512,
                                   num_workers=2)
    test_loader = data.DataLoader(dataset=test_data, shuffle=True, batch_size=256,
                                  num_workers=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = Net().to(device)
    if os.path.exists(save_path):
        net.load_state_dict(torch.load(save_path))
    else:
        print("No Param")
    # The net already applies log_softmax, so pair it with NLLLoss:
    # log_softmax + NLLLoss == CrossEntropyLoss == softmax loss.
    loss_fn = nn.NLLLoss()
    # Original note: momentum 0.9 for the first ~10 epochs, 0.3 for the next
    # ~10, then 0 — left at 0 here, as in the original.
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0)
    # torch.save below fails if the models/ directory does not exist.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    for epoch in range(100000):
        net.train()
        feat_loader = []
        label_loader = []
        for i, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            feature, output = net(x)  # feature: (N, 2), output: (N, 10)
            loss_cls = loss_fn(output, y)  # output is already log_softmax'd
            y = y.float()  # center_loss needs float labels for torch.histc
            loss_center = center_loss(feature, y, 0.5)  # small weight works better than 2
            loss = loss_cls + loss_center  # softmax loss + center loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Detach before storing; keeping live autograd graphs for a whole
            # epoch made memory grow in the original.
            feat_loader.append(feature.detach())
            label_loader.append(y.detach())
            if i % 10 == 0:
                print("epoch:", epoch, "i:", i, "total_loss:", loss.item(),
                      "Softmax_loss", loss_cls.item(), "center_loss", loss_center.item())
        feat = torch.cat(feat_loader, 0)      # (60000, 2)
        labels = torch.cat(label_loader, 0)   # (60000,)
        net.visualize(feat.data.cpu().numpy(), labels.data.cpu().numpy(), epoch)
        torch.save(net.state_dict(), save_path)
        # --- evaluation ---
        net.eval()  # BatchNorm must use running stats here, not batch stats
        eval_loss_cls = 0
        eval_acc_cls = 0
        with torch.no_grad():  # no gradients needed for evaluation
            for i, (x, y) in enumerate(test_loader):
                x = x.to(device)
                y = y.to(device)
                feature, output = net(x)
                loss_cls = loss_fn(output, y)
                # The original also computed an (unused) eval center loss with
                # freshly random centers; dropped as dead code.
                eval_loss_cls += loss_cls.item() * y.size(0)
                out_argmax = torch.argmax(output, 1)
                eval_acc_cls += (out_argmax == y).sum().item()
        mean_loss_cls = eval_loss_cls / len(test_data)
        mean_acc_cls = eval_acc_cls / len(test_data)
        print("分类平均损失:{} 分类平均精度{}".format(mean_loss_cls, mean_acc_cls))
    # TODO (from the original notes): make the centers trainable by writing the
    # center loss as a class; try other optimizers; try SGD with lr=0.5.