利用Pytorch框架实现softmax回归的图像分类
1.导入基础包
import torch
from torch import nn
from torch.nn import init
import torchvision
import torchvision.transforms as transforms
import numpy as np
import sys
# sys.path.append('..')
# import d21zh_pytorch as d21
2.获取并读取数据
def load_data_fashion_mnist(batch_size):
    """Download (if needed) and load Fashion-MNIST, returning DataLoaders.

    Args:
        batch_size: number of samples per mini-batch.

    Returns:
        (train_iter, test_iter): DataLoaders over the train and test splits;
        the train loader shuffles, the test loader does not.
    """
    if sys.platform.startswith('win'):
        # On Windows, extra DataLoader worker processes require a __main__
        # guard and are fragile in scripts/notebooks. 0 loads data in the
        # main process — this matches the original comment's intent
        # ("0 means no extra processes"), which the old value 2 contradicted.
        num_workers = 0
    else:
        num_workers = 4
    mnist_train = torchvision.datasets.FashionMNIST(
        root='~/Datasets/FashionMNIST', train=True, download=True,
        transform=transforms.ToTensor())
    mnist_test = torchvision.datasets.FashionMNIST(
        root='~/Datasets/FashionMNIST', train=False, download=True,
        transform=transforms.ToTensor())
    train_iter = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)
    return train_iter, test_iter
# Mini-batch size used for both training and evaluation
batch_size=256
train_iter,test_iter=load_data_fashion_mnist(batch_size)
3.定义和初始化模型
softmax回归的输出层是一个全连接层,所以用一个线性模块就可以了。每个batch样本x的形状为(batch_size,1,28,28),所以我们要先用view()将x的形状转换成(batch_size,784)才送入全连接层
num_inputs=784  # 28*28 pixels per image, flattened
num_outputs=10  # ten Fashion-MNIST classes
class LinearNet(nn.Module):
    """Softmax regression as a single fully connected layer.

    Each input batch has shape (batch_size, 1, 28, 28); forward() flattens
    it to (batch_size, num_inputs) before the linear layer. The softmax
    itself is folded into the loss (nn.CrossEntropyLoss), so this module
    outputs raw per-class scores.
    """

    def __init__(self, num_inputs, num_outputs):
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        # Flatten every dimension except the batch dimension.
        # (Debug print statements removed — they fired on every forward pass.)
        return self.linear(x.view(x.shape[0], -1))
net=LinearNet(num_inputs,num_outputs)#build the linear model: 784 inputs, 10 outputs
#FlattenLayer below reshapes x so a plain nn.Linear can be used in a Sequential
class FlattenLayer(nn.Module):
    """Reshape input to (batch_size, -1), preserving only the batch axis."""

    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
# print(net)
# for x,y in train_iter:
#     net.forward(x)
from collections import OrderedDict
# Rebind `net` as a Sequential pipeline (flatten, then linear); this
# replaces the LinearNet instance built above and is what gets trained.
net=nn.Sequential(
    OrderedDict([
        ('flatten',FlattenLayer()),
        ('linear',nn.Linear(num_inputs,num_outputs))
    ])
)
# print(net[1])
# print(net.linear)
# print(net.flatten)
# for param in net.parameters():
#     print(param.shape)
使用均值为0、标准差为0.01的正态分布随机初始化模型的权重参数
# Initialize weights ~ N(0, 0.01) and biases to zero, in place
init.normal_(net.linear.weight,mean=0,std=0.01)
init.constant_(net.linear.bias,val=0)
4.softmax和交叉熵损失函数
# This loss combines the softmax operation and the cross-entropy
# computation in one numerically stable call
loss=nn.CrossEntropyLoss()
5.定义优化算法
# for param in net.parameters():
#     print(param.shape)
optimizer=torch.optim.SGD(net.parameters(),lr=0.1)  # plain SGD, learning rate 0.1
6.训练模型
num_epochs=5  # number of full passes over the training set
def evaluate_accuracy(data_iter, net):
    """Compute the classification accuracy of `net` over `data_iter`.

    Args:
        data_iter: iterable of (x, y) mini-batches.
        net: callable mapping x to per-class scores of shape (batch, classes).

    Returns:
        Fraction of samples whose argmax prediction equals the label.
    """
    acc_sum, n = 0.0, 0
    # Evaluation only — skip building the autograd graph (same results,
    # less memory than the original).
    with torch.no_grad():
        for x, y in data_iter:
            # .sum() counts correct predictions in the batch (not a mean)
            acc_sum += (net(x).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n
# print(evaluate_accuracy(test_iter,net))
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    """Train `net` for `num_epochs` epochs, printing loss/accuracy per epoch.

    Args:
        net: the model; called as net(x).
        train_iter, test_iter: iterables of (x, y) mini-batches.
        loss: loss function; called as loss(y_hat, y).
        num_epochs: number of passes over train_iter.
        batch_size: mini-batch size (only used by the manual-sgd branch).
        params, lr: parameters and learning rate for the manual-sgd branch;
            leave None when an optimizer is supplied.
        optimizer: a torch.optim optimizer; when given, it is used for
            the update step instead of the manual sgd helper.
    """
    for epoch in range(num_epochs):
        # Fixed typo: was `train_1_sum` (digit one) — the 'l' stands for loss.
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for x, y in train_iter:
            y_hat = net(x)
            # CrossEntropyLoss already reduces over the batch; .sum() on the
            # resulting scalar is a no-op kept for parity with the book code.
            batch_loss = loss(y_hat, y).sum()
            # Zero gradients before backward
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            batch_loss.backward()
            if optimizer is None:
                # NOTE(review): `sgd` is not defined anywhere in this file —
                # presumably the d2l helper; this branch only works if it is
                # imported. TODO confirm before using the params/lr path.
                sgd(params, lr, batch_size)
            else:
                optimizer.step()  # the concise (built-in optimizer) path
            train_l_sum += batch_loss.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d loss % .4f train_acc %.3f test_acc %.3f ' % (epoch+1,train_l_sum/n,train_acc_sum/n,test_acc))
# Train via the optimizer path (params and lr are unused, hence None)
train_ch3(net,train_iter,test_iter,loss,num_epochs,batch_size,None,None,optimizer)
运行结果:
epoch 1 loss 0.0031 train_acc 0.750 test_acc 0.774
epoch 2 loss 0.0022 train_acc 0.815 test_acc 0.797
epoch 3 loss 0.0021 train_acc 0.824 test_acc 0.818
epoch 4 loss 0.0020 train_acc 0.833 test_acc 0.816
epoch 5 loss 0.0019 train_acc 0.836 test_acc 0.821
7.预测
def get_fashion_mnist_labels(labels):
    """Map numeric Fashion-MNIST class indices to their text names."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress',
                   'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return list(map(text_labels.__getitem__, labels))
#figure display settings
def use_svg_display():
    """Render figures as SVG (vector graphics).

    NOTE(review): the body is empty — the d2l original calls
    display.set_matplotlib_formats('svg') here; presumably that line was
    lost when copying. As written this function does nothing. TODO confirm.
    """
def show_fashion_mnist(images, labels):
    """Plot several images in one row, each titled with its label.

    Args:
        images: iterable of tensors viewable as 28x28 images.
        labels: iterable of title strings, one per image.
    """
    # `plt` is used but never imported anywhere in this file — import it
    # locally so the function is self-contained.
    import matplotlib.pyplot as plt
    use_svg_display()
    # '_' holds the Figure object, which we do not need
    _, figs = plt.subplots(1, len(images), figsize=(25, 25))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl, color='white')
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
# 7. Prediction: show the first images of the test set with true/predicted labels.
# Fixed: `iter(test_iter).next()` is Python-2 style and raises AttributeError
# on modern DataLoader iterators — use the built-in next().
x, y = next(iter(test_iter))
true_labels = get_fashion_mnist_labels(y.numpy())
pred_labels = get_fashion_mnist_labels(net(x).argmax(dim=1).numpy())
# Title format: true label on the first line, prediction on the second
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
show_fashion_mnist(x[0:9], titles[0:9])
运行结果: