RNN实现MNIST数据集分类

# 1. 加载数据集

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 2. 下载 mnist 数据集

trainsets = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())  # 格式转换
testsets = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())

class_names = trainsets.classes  # 查看类别/标签
print(class_names)

# 3. 查看数据集的大小shape

print(trainsets.data.shape)  

print(trainsets.targets.shape)

print(testsets.data.shape)

print(testsets.targets.shape)

# 4. 定义超参数

BATCH_SIZE = 32  # 每批读取的数据大小
EPOCHS = 10  # 训练 10 轮

# 5. 创建数据集的可迭代对象,也就是说一个batch 一个batch的读取数据

train_loader = torch.utils.data.DataLoader(dataset=trainsets, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testsets, batch_size=BATCH_SIZE, shuffle=True)

images, labels = next(iter(test_loader))  # 查看一批batch的数据

print(images.shape)

print(labels.shape)

# 6. 定义函数:显示一批数据

def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])  # 均值
    std = np.array([0.229, 0.224, 0.225])  # 标准差
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)  # 限速值限制在0-1之间
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)    

# 网格显示
out = torchvision.utils.make_grid(images)
imshow(out)

# 7. 定义RNN模型

class RNN_Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(RNN_Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
        # 全连接层
        self.fc = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        # (layer_dim, batch_size, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_().to(device)
        # 分离隐藏状态,避免梯度爆炸
        out, hn = self.rnn(x, h0.detach())
        out = self.fc(out[:, -1, :])
        return out     

# 8. 初始化模型

input_dim = 28  # 输入维度
hidden_dim = 100  # 隐层的维度
layer_dim = 2  # 2层RNN
output_dim = 10  # 输出维度

# 判断是否有GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = RNN_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)

# 9. 定义损失函数

criterion = nn.CrossEntropyLoss()

# 10. 定义优化器

learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 11. 输出模型参数信息

length = len(list(model.parameters()))

# 12. 循环打印模型参数

for i in range(length):
    print('参数: %d'%(i+1))
    print(list(model.parameters())[i].size())

# 13. 模型训练

sequence_dim = 28  # 序列长度
loss_list = []  # 保存loss
accuracy_list = []  # 保存accuracy
iteration_list = []  # 保存循环次数

iter = 0
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        model.train()  # 声明训练
        images = images.view(-1, sequence_dim, input_dim). requires_grad_().to(device)
        labels = labels.to(device)
                
        # 梯度清零(否则会不断累加)
        optimizer.zero_grad()
        # 前向传播
        outputs = model(images)
        # 计算损失
        loss = criterion(outputs, labels)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 计数器自动加 1
        iter += 1
        # 模型验证
        if iter % 500 == 0:
            model.eval()  # 声明
            # 计算验证的accuracy
            correct = 0.0
            total = 0.0
            # 迭代测试集,获取数据,预测
            for images, labels in test_loader:
                images = images.view(-1, sequence_dim, input_dim).to(device)
                # 模型预测
                outputs = model(images)
                # 获取预测概率最大值的下标
                predict = torch.max(outputs.data, 1)[1]
                # 统计测试集的大小
                total += labels.size(0)
                # 统计判断/预测正确的数量
                if torch.cuda.is_available():
                    correct += (predict.cuda() == labels.cuda()).sum().item()
                else:
                    correct += (predict == labels).sum().item()
                    
            # 计算
            accuracy = correct / total * 100
            # 保存accuracy, loss, iteration
            loss_list.append(loss.data)
            accuracy_list.append(accuracy)
            iteration_list.append(iter)
            # 打印信息
            print("loop : {}, Loss : {}, Accuracy : {}".format(iter, loss.item(), accuracy))

# 可视化 loss

plt.plot(iteration_list, loss_list)
plt.xlabel('Number of Iteration')
plt.ylabel('Loss')
plt.title('RNN')
plt.show()

# 可视化 accuracy

plt.plot(iteration_list, accuracy_list, color='r')
plt.xlabel('Number of Iteration')
plt.ylabel('Accuracy')
plt.title('LSTM')
plt.savefig('LSTM_mnist.png')
plt.show()

训练结果:

loop : 500, Loss : 2.304194450378418, Accuracy : 10.26
loop : 1000, Loss : 2.290687322616577, Accuracy : 19.400000000000002
loop : 1500, Loss : 2.279113292694092, Accuracy : 19.07
loop : 2000, Loss : 1.5382373332977295, Accuracy : 42.91
loop : 2500, Loss : 1.4032894372940063, Accuracy : 47.57
loop : 3000, Loss : 0.6646756529808044, Accuracy : 72.8
loop : 3500, Loss : 0.5376549363136292, Accuracy : 82.04
loop : 4000, Loss : 0.6527548432350159, Accuracy : 77.06
loop : 4500, Loss : 0.22894516587257385, Accuracy : 84.63000000000001
loop : 5000, Loss : 0.33490198850631714, Accuracy : 89.14
loop : 5500, Loss : 0.4797677993774414, Accuracy : 89.52
loop : 6000, Loss : 0.283376008272171, Accuracy : 91.72
loop : 6500, Loss : 0.38564950227737427, Accuracy : 92.64
loop : 7000, Loss : 0.036136776208877563, Accuracy : 93.17
loop : 7500, Loss : 0.2951360046863556, Accuracy : 94.28
loop : 8000, Loss : 0.07122373580932617, Accuracy : 93.97999999999999
loop : 8500, Loss : 0.2584732472896576, Accuracy : 94.86
loop : 9000, Loss : 0.25881877541542053, Accuracy : 93.89999999999999
loop : 9500, Loss : 0.13154897093772888, Accuracy : 95.30999999999999
loop : 10000, Loss : 0.17995546758174896, Accuracy : 95.48
loop : 10500, Loss : 0.2594304084777832, Accuracy : 95.42
loop : 11000, Loss : 0.06235146522521973, Accuracy : 95.42
loop : 11500, Loss : 0.03526287525892258, Accuracy : 96.39999999999999
loop : 12000, Loss : 0.4116947650909424, Accuracy : 94.85
loop : 12500, Loss : 0.036189839243888855, Accuracy : 96.6
loop : 13000, Loss : 0.2917410433292389, Accuracy : 95.14
loop : 13500, Loss : 0.053200021386146545, Accuracy : 96.5
loop : 14000, Loss : 0.036753542721271515, Accuracy : 96.75
loop : 14500, Loss : 0.18110425770282745, Accuracy : 96.73
loop : 15000, Loss : 0.16734498739242554, Accuracy : 96.24000000000001
loop : 15500, Loss : 0.2706497013568878, Accuracy : 97.22
loop : 16000, Loss : 0.1784251183271408, Accuracy : 97.0
loop : 16500, Loss : 0.03909716010093689, Accuracy : 97.05
loop : 17000, Loss : 0.09333514422178268, Accuracy : 96.64
loop : 17500, Loss : 0.17319414019584656, Accuracy : 96.31
loop : 18000, Loss : 0.20184077322483063, Accuracy : 96.48
loop : 18500, Loss : 0.00786609947681427, Accuracy : 97.28

可视化结果:在这里插入图片描述在这里插入图片描述

  • 2
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值