# 1. Import the required libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 2. Download the MNIST dataset
trainsets = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()) # convert images to tensors
testsets = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())
class_names = trainsets.classes # inspect the class labels
print(class_names)
# 3. Inspect the dataset shapes
print(trainsets.data.shape)
print(trainsets.targets.shape)
print(testsets.data.shape)
print(testsets.targets.shape)
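# For reference, the standard MNIST split (60,000 train / 10,000 test images
# of 28x28 pixels) should print:
#   torch.Size([60000, 28, 28])
#   torch.Size([60000])
#   torch.Size([10000, 28, 28])
#   torch.Size([10000])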
# 4. Define hyperparameters
BATCH_SIZE = 32 # number of samples read per batch
EPOCHS = 10 # train for 10 epochs
# 5. Wrap the datasets in iterable loaders, i.e. read the data one batch at a time
train_loader = torch.utils.data.DataLoader(dataset=trainsets, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testsets, batch_size=BATCH_SIZE, shuffle=False) # no need to shuffle the test set
images, labels = next(iter(test_loader)) # peek at one batch
print(images.shape)
print(labels.shape)
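# Each 28x28 image is fed to the RNN as a sequence of 28 rows with 28 pixel
# values per step. A minimal sketch of the reshape the training loop performs later:
seq_batch = images.view(-1, 28, 28)  # (batch, 1, 28, 28) -> (batch, seq_len=28, input_dim=28)
print(seq_batch.shape)  # torch.Size([32, 28, 28])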
# 6. Helper function: display a batch of images
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    # ImageNet mean/std; this un-normalization only matters if a Normalize
    # transform was applied (here only ToTensor is used)
    mean = np.array([0.485, 0.456, 0.406]) # mean
    std = np.array([0.229, 0.224, 0.225]) # standard deviation
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1) # clamp pixel values to [0, 1]
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
# tile the batch into a single grid image
out = torchvision.utils.make_grid(images)
imshow(out)
# 7. Define the RNN model
class RNN_Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(RNN_Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True -> input shape (batch, seq, feature)
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
        # fully connected output layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # initial hidden state: (layer_dim, batch_size, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(device)
        # detach the hidden state so gradients do not propagate into it
        out, hn = self.rnn(x, h0.detach())
        # classify using the output of the last time step
        out = self.fc(out[:, -1, :])
        return out
# 8. Initialize the model
input_dim = 28 # input dimension (pixels per image row)
hidden_dim = 100 # hidden state dimension
layer_dim = 2 # 2 stacked RNN layers
output_dim = 10 # output dimension (10 digit classes)
# use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = RNN_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)
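# Sanity check (a minimal sketch, not part of the original run): push a dummy
# batch through the model and confirm the output shape is (batch, output_dim)
dummy = torch.zeros(4, 28, input_dim).to(device)  # 4 images as 28-step sequences
print(model(dummy).shape)  # expected: torch.Size([4, 10])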
# 9. Define the loss function
criterion = nn.CrossEntropyLoss()
# 10. Define the optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
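# Note: plain RNNs with a ReLU nonlinearity are prone to exploding gradients.
# If training diverges, a common remedy (not used in the run below) is to clip
# gradients between loss.backward() and optimizer.step(), e.g.:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)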
# 11. Count the model's parameter tensors
length = len(list(model.parameters()))
# 12. Print the shape of each parameter tensor
for i in range(length):
    print('Parameter %d:' % (i + 1))
    print(list(model.parameters())[i].size())
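# For this configuration the loop prints 10 tensors: weight_ih, weight_hh,
# bias_ih, bias_hh for each of the 2 RNN layers (layer 0: [100, 28],
# [100, 100], [100], [100]), plus fc.weight [10, 100] and fc.bias [10]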
# 13. Train the model
sequence_dim = 28 # sequence length (one image row per time step)
loss_list = [] # record loss
accuracy_list = [] # record accuracy
iteration_list = [] # record iteration count
iteration = 0
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        model.train() # switch to training mode
        # reshape each image into a (sequence_dim, input_dim) sequence
        images = images.view(-1, sequence_dim, input_dim).to(device)
        labels = labels.to(device)
        # zero the gradients (they accumulate otherwise)
        optimizer.zero_grad()
        # forward pass
        outputs = model(images)
        # compute the loss
        loss = criterion(outputs, labels)
        # backward pass
        loss.backward()
        # update the parameters
        optimizer.step()
        # bump the iteration counter
        iteration += 1
        # evaluate the model every 500 iterations
        if iteration % 500 == 0:
            model.eval() # switch to evaluation mode
            # compute the test accuracy
            correct = 0.0
            total = 0.0
            # iterate over the test set and count correct predictions
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.view(-1, sequence_dim, input_dim).to(device)
                    labels = labels.to(device)
                    # model prediction
                    outputs = model(images)
                    # index of the highest score = predicted class
                    predict = torch.max(outputs.data, 1)[1]
                    # accumulate the number of test samples
                    total += labels.size(0)
                    # accumulate the number of correct predictions
                    correct += (predict == labels).sum().item()
            # compute the accuracy
            accuracy = correct / total * 100
            # record accuracy, loss, and iteration
            loss_list.append(loss.item())
            accuracy_list.append(accuracy)
            iteration_list.append(iteration)
            # print progress
            print("loop : {}, Loss : {}, Accuracy : {}".format(iteration, loss.item(), accuracy))
# Visualize the loss
plt.plot(iteration_list, loss_list)
plt.xlabel('Number of Iteration')
plt.ylabel('Loss')
plt.title('RNN')
plt.show()
# Visualize the accuracy
plt.plot(iteration_list, accuracy_list, color='r')
plt.xlabel('Number of Iteration')
plt.ylabel('Accuracy')
plt.title('RNN')
plt.savefig('RNN_mnist.png')
plt.show()
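# A minimal inference sketch using the trained model: classify one test image
model.eval()
with torch.no_grad():
    image, label = testsets[0]
    logits = model(image.view(-1, sequence_dim, input_dim).to(device))
    print('predicted:', logits.argmax(dim=1).item(), 'actual:', label)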
Training results:
loop : 500, Loss : 2.304194450378418, Accuracy : 10.26
loop : 1000, Loss : 2.290687322616577, Accuracy : 19.400000000000002
loop : 1500, Loss : 2.279113292694092, Accuracy : 19.07
loop : 2000, Loss : 1.5382373332977295, Accuracy : 42.91
loop : 2500, Loss : 1.4032894372940063, Accuracy : 47.57
loop : 3000, Loss : 0.6646756529808044, Accuracy : 72.8
loop : 3500, Loss : 0.5376549363136292, Accuracy : 82.04
loop : 4000, Loss : 0.6527548432350159, Accuracy : 77.06
loop : 4500, Loss : 0.22894516587257385, Accuracy : 84.63000000000001
loop : 5000, Loss : 0.33490198850631714, Accuracy : 89.14
loop : 5500, Loss : 0.4797677993774414, Accuracy : 89.52
loop : 6000, Loss : 0.283376008272171, Accuracy : 91.72
loop : 6500, Loss : 0.38564950227737427, Accuracy : 92.64
loop : 7000, Loss : 0.036136776208877563, Accuracy : 93.17
loop : 7500, Loss : 0.2951360046863556, Accuracy : 94.28
loop : 8000, Loss : 0.07122373580932617, Accuracy : 93.97999999999999
loop : 8500, Loss : 0.2584732472896576, Accuracy : 94.86
loop : 9000, Loss : 0.25881877541542053, Accuracy : 93.89999999999999
loop : 9500, Loss : 0.13154897093772888, Accuracy : 95.30999999999999
loop : 10000, Loss : 0.17995546758174896, Accuracy : 95.48
loop : 10500, Loss : 0.2594304084777832, Accuracy : 95.42
loop : 11000, Loss : 0.06235146522521973, Accuracy : 95.42
loop : 11500, Loss : 0.03526287525892258, Accuracy : 96.39999999999999
loop : 12000, Loss : 0.4116947650909424, Accuracy : 94.85
loop : 12500, Loss : 0.036189839243888855, Accuracy : 96.6
loop : 13000, Loss : 0.2917410433292389, Accuracy : 95.14
loop : 13500, Loss : 0.053200021386146545, Accuracy : 96.5
loop : 14000, Loss : 0.036753542721271515, Accuracy : 96.75
loop : 14500, Loss : 0.18110425770282745, Accuracy : 96.73
loop : 15000, Loss : 0.16734498739242554, Accuracy : 96.24000000000001
loop : 15500, Loss : 0.2706497013568878, Accuracy : 97.22
loop : 16000, Loss : 0.1784251183271408, Accuracy : 97.0
loop : 16500, Loss : 0.03909716010093689, Accuracy : 97.05
loop : 17000, Loss : 0.09333514422178268, Accuracy : 96.64
loop : 17500, Loss : 0.17319414019584656, Accuracy : 96.31
loop : 18000, Loss : 0.20184077322483063, Accuracy : 96.48
loop : 18500, Loss : 0.00786609947681427, Accuracy : 97.28
Visualization results: