多层全连接网络:实现 MNIST 手写数字识别(训练 50 轮,测试集准确率 92.1%)
1 导入必备库
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Show which PyTorch build is installed.
print(f"{torch.__version__}")
输出:
1.12.1+cu102
2 加载 MNIST 数据集(torchvision 内置了常用数据集和最常见的模型)
import torchvision
from torchvision.transforms import ToTensor
# transforms.ToTensor:
#   1. converts each image into a tensor;
#   2. scales pixel values into the range [0, 1];
#   3. moves the channel axis to the first dimension (C, H, W).
# download=False: the MNIST files must already be present under 'data/'.
train_ds, test_ds = (
    torchvision.datasets.MNIST(
        'data/',
        train=is_train,
        transform=ToTensor(),
        download=False,
    )
    for is_train in (True, False)
)
print(len(train_ds), len(test_ds))
输出:
60000 10000
3 数据批量加载
# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_dl = torch.utils.data.DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(dataset=test_ds, batch_size=256)

# Wrap the loader in an iterator and pull one batch to check the tensor shapes.
imgs, labels = next(iter(train_dl))
print(imgs.shape, labels.shape, sep="\n")
输出:
torch.Size([64, 1, 28, 28])
torch.Size([64])
4 绘制样例
# Show the first 10 images of the sampled batch in one row, each labelled with its digit.
os.makedirs('pics', exist_ok=True)  # savefig does not create missing folders
plt.figure(figsize=(10, 1))
for i, img in enumerate(imgs[:10]):
    plt.subplot(1, 10, i + 1)
    # img is (1, 28, 28); squeeze drops the channel axis so imshow gets a 2-D array.
    # MNIST digits are single-channel, so render them in grayscale rather than
    # the default false-colour map.
    plt.imshow(np.squeeze(img.numpy()), cmap='gray')
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(labels[i].item())
plt.savefig('pics/3.1.jpg', dpi=400)
5 创建模型
class Model(nn.Module):
    """Three-layer fully connected classifier for flattened images.

    The defaults reproduce the 784 -> 120 -> 84 -> 10 network for 28x28
    MNIST digits.

    Args:
        in_features: number of input values per sample after flattening.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_features=28 * 28, hidden1=120, hidden2=84, num_classes=10):
        super().__init__()
        # Attribute names ("liner_*") are kept unchanged: they appear in the
        # printed model and are the key prefixes of any saved state_dict.
        self.liner_1 = nn.Linear(in_features, hidden1)
        self.liner_2 = nn.Linear(hidden1, hidden2)
        self.liner_3 = nn.Linear(hidden2, num_classes)

    def forward(self, input):
        """Return raw class logits of shape (N, num_classes).

        ``input`` is flattened to (N, in_features); e.g. an (N, 1, 28, 28)
        image batch becomes (N, 784). No softmax is applied: the logits are
        meant to be fed to ``nn.CrossEntropyLoss``.
        """
        # reshape (unlike view) also accepts non-contiguous tensors.
        x = input.reshape(-1, self.liner_1.in_features)
        x = F.relu(self.liner_1(x))
        x = F.relu(self.liner_2(x))
        return self.liner_3(x)
7 设置是否使用GPU
# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using {} device".format(device))

# Build the network and move its parameters to the chosen device.
model = Model()
model = model.to(device)
print(model)
输出:
Using cuda device
Model(
(liner_1): Linear(in_features=784, out_features=120, bias=True)
(liner_2): Linear(in_features=120, out_features=84, bias=True)
(liner_3): Linear(in_features=84, out_features=10, bias=True)
)
8 设置损失函数和优化器
# Cross-entropy on the model's raw logits for the 10-way classification.
loss_fn = nn.CrossEntropyLoss()
# Plain SGD (default: no momentum) with a small learning rate over all weights.
optimizer = optim.SGD(params=model.parameters(), lr=1e-3)
9 定义训练函数
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch and return (mean loss per sample, accuracy).

    Args:
        dataloader: yields (X, y) batches; ``len(dataloader.dataset)`` must be
            the total number of samples.
        model: the network to train; each batch is moved to the device of the
            model's parameters.
        loss_fn: criterion returning the *mean* loss over a batch
            (e.g. ``nn.CrossEntropyLoss()`` with its default reduction).
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        (train_loss, accuracy): the loss averaged over all samples, and the
        fraction of samples whose arg-max prediction matched the label. Both
        are measured on the fly while the weights are being updated.
    """
    model.train()  # explicit training mode (matters for dropout/batch-norm layers)
    device = next(model.parameters()).device
    size = len(dataloader.dataset)
    train_loss, correct = 0.0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # loss is a batch *mean*: weight it by the batch size so the epoch
        # value is a true per-sample average (the last batch may be smaller),
        # independent of batch_size and comparable with an evaluation loss.
        batch_size = y.size(0)
        train_loss += loss.item() * batch_size
        correct += (pred.argmax(1) == y).sum().item()
    return train_loss / size, correct / size
10 定义测试函数
def test(dataloader, model, loss_fn=None):
    """Evaluate ``model`` on ``dataloader``; return (mean loss per sample, accuracy).

    Args:
        dataloader: yields (X, y) batches; ``len(dataloader.dataset)`` must be
            the total number of samples.
        model: the network to evaluate; each batch is moved to the device of
            the model's parameters. The model's train/eval mode is restored
            before returning.
        loss_fn: criterion returning the mean loss over a batch. Optional so
            existing ``test(dataloader, model)`` calls keep working; defaults
            to a fresh ``nn.CrossEntropyLoss()``.

    Returns:
        (test_loss, accuracy): the loss averaged over all samples and the
        fraction of samples whose arg-max prediction matched the label.
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    device = next(model.parameters()).device
    size = len(dataloader.dataset)
    test_loss, correct = 0.0, 0
    was_training = model.training
    model.eval()  # dropout/batch-norm layers (if any) switch to inference behaviour
    try:
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                # Weight the batch-mean loss by the batch size -> per-sample average.
                test_loss += loss_fn(pred, y).item() * y.size(0)
                correct += (pred.argmax(1) == y).sum().item()
    finally:
        model.train(was_training)
    return test_loss / size, correct / size


# The name starts with "test": keep pytest from collecting this helper as a test case.
test.__test__ = False
11 开始训练
epochs = 50
# Per-epoch history, used afterwards to plot the loss and accuracy curves.
train_loss, train_acc = [], []
test_loss, test_acc = [], []

# Built once, outside the loop.
template = ("epoch:{:2d}/{:2d}, train_loss: {:.5f}, train_acc: {:.1f}%, "
            "test_loss: {:.5f}, test_acc: {:.1f}%")

for epoch in range(epochs):
    epoch_loss, epoch_acc = train(train_dl, model, loss_fn, optimizer)
    epoch_test_loss, epoch_test_acc = test(test_dl, model)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    # Accuracies are fractions in [0, 1]; print them as percentages.
    print(template.format(epoch + 1, epochs, epoch_loss, epoch_acc * 100,
                          epoch_test_loss, epoch_test_acc * 100))
print("Done!")
输出:
epoch: 1/50, train_loss: 0.02354, train_acc: 68.0% ,test_loss: 0.00537, test_acc: 70.7%
epoch: 2/50, train_loss: 0.01930, train_acc: 71.9% ,test_loss: 0.00437, test_acc: 74.9%
epoch: 3/50, train_loss: 0.01598, train_acc: 76.0% ,test_loss: 0.00366, test_acc: 78.5%
······
epoch:48/50, train_loss: 0.00455, train_acc: 91.7% ,test_loss: 0.00111, test_acc: 92.0%
epoch:49/50, train_loss: 0.00451, train_acc: 91.8% ,test_loss: 0.00110, test_acc: 92.1%
epoch:50/50, train_loss: 0.00449, train_acc: 91.8% ,test_loss: 0.00109, test_acc: 92.1%
Done!
12 绘制损失曲线并保存
# Start a fresh figure: when run as a plain script (not notebook cells), the
# curves would otherwise be drawn into whatever figure/axes is still current.
plt.figure()
plt.plot(range(1, epochs + 1), train_loss, label='train_loss', lw=2)
plt.plot(range(1, epochs + 1), test_loss, label='test_loss', lw=2, ls="--")
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
os.makedirs('pics', exist_ok=True)  # savefig does not create missing folders
plt.savefig('pics/2-4-3.jpg', dpi=400)
输出:(损失曲线图,保存为 pics/2-4-3.jpg)
13 绘制准确率曲线并保存
# Start a fresh figure: when run as a plain script (not notebook cells), the
# accuracy curves would otherwise be overlaid on the previous plot.
plt.figure()
plt.plot(range(1, epochs + 1), train_acc, label='train_acc', lw=2)
plt.plot(range(1, epochs + 1), test_acc, label='test_acc', lw=2, ls="--")
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
os.makedirs('pics', exist_ok=True)  # savefig does not create missing folders
plt.savefig('pics/2-4-4.jpg', dpi=400)
输出:(准确率曲线图,保存为 pics/2-4-4.jpg)