PyTorch笔记 入门:写一个简单的神经网络4:RNN(以MNIST数据集为例)

相关视频:
PyTorch 动态神经网络 (莫烦 Python 教学)
【李宏毅】2020 最新课程 (完整版) Machine Learning (2020)

强烈推荐李宏毅的课,RNN、LSTM讲得很清晰。

一、导入库、设置超参数

在这里插入图片描述

二、下载、读取MNIST数据集、DataLoader

详细见上一篇文章:PyTorch笔记 入门:写一个简单的神经网络3:CNN(以MNIST数据集为例)

在这里插入图片描述在这里插入图片描述

三、创建网络

在这里插入图片描述

四、优化器、损失函数

在这里插入图片描述

五、训练网络

在这里插入图片描述

六、预测

在这里插入图片描述

七、完整代码

import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 创建网络
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.lstm_layer = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=64,
            num_layers=1,
            batch_first=True, # 若为True,则输入数据的维度为(batch_size, time_step, input_size),否则维度为(time_step, batch_size, input_size)
        )
        self.output_layer = nn.Linear(64, 10)
        
    def forward(self, x):
        out, (n, c) = self.lstm_layer(x, None)
        output = self.output_layer(out[:, -1, :])
        return output

EPOCH = 1
BATCH_SIZE = 64
TIME_STEP = 28 # 一共有28行,每步读取一行
INPUT_SIZE = 28 # 每一行有28列
LR = 0.01
DOWNLOAD = False

# 下载mnist数据
train_data = datasets.MNIST(
    root='./data', # 保存路径
    train=True, # True表示训练集,False表示测试集
    transform=transforms.ToTensor(), # 将0~255压缩为0~1
    download=DOWNLOAD
)

# DataLoader
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)

# 测试集
test_data = datasets.MNIST(
    root='./data',
    train=False
)

print(test_data.data.size())
print(test_data.targets.size())

# 为了节约时间,只使用测试集的前2000个数据
test_x = Variable(
    torch.unsqueeze(test_data.data, dim=1),
    volatile=True
).type(torch.FloatTensor)[:2000]/255 # 将将0~255压缩为0~1

test_y = test_data.targets[:2000]

rnn = RNN()
print(rnn)

# 优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)

# 损失函数
loss_func = nn.CrossEntropyLoss()

# 训练神经网络
for epoch in range(EPOCH):
    for step, (batch_x, batch_y) in enumerate(train_loader):
        output = rnn(batch_x.reshape(-1, 28, 28))
        loss = loss_func(output, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 每隔50步输出一次信息
        if step%50 == 0:
            test_output = rnn(test_x.reshape(-1, 28, 28))
            predict_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = (predict_y == test_y).sum().item() / test_y.size(0)
            print('Epoch', epoch, '|', 'Step', step, '|', 'Loss', loss.data.item(), '|', 'Test Accuracy', accuracy)
            
# 预测
test_output = rnn(test_x[:100].reshape(-1, 28, 28))
predict_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
real_y = test_y[:100].numpy()
print(predict_y)
print(real_y)

# 打印预测和实际结果
for i in range(10):
    print('Predict', predict_y[i])
    print('Real', real_y[i])
    plt.imshow(test_data.data[i].numpy(), cmap='gray')
    plt.show()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值