15.上一次rnn学习的加工【代码详解】

'''
Author: 365JHWZGo
Description: 15.上一次rnn学习的加工
Date: 2021/10/30 19:45
FilePath: day11-2.py
'''

# 导包
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data as Data

torch.manual_seed(1)

# 超参数
BATCH_SIZE = 64         # 一遍中每一次只训练64张图片
EPOCH = 1               #  对2000张照片只训练1遍
TIME_STEP = 28          # 对一张图片进行逐行扫描,一共扫描28行
INPUT_SIZE = 28         # 逐行扫描28个像素点
LR = 0.01               # optimizer的学习效率为0.01
DOWNLOAD_MNIST = False  # 是否要下载MNIST

# 判断MNIST数据集是否已经在我的目录里,即是否已经下载好
if not (os.path.exists('../mnist')) or not os.listdir('../mnist'):
    DOWNLOAD_MNIST = True

# 从MNIST中获取train_data
train_data = torchvision.datasets.MNIST(
    root='../mnist',                                # 在哪里寻找MNIST
    download=DOWNLOAD_MNIST,                        # 是否要下载MNIST
    train=True,                                     # 是否为训练数据
    transform=torchvision.transforms.ToTensor()     # 将train_data做一些转化【将PIL(python image library)图片转化为某种格式】
                                                    # 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示
)

# 建立train_data的数据加载器,方便分批训练
train_loader = Data.DataLoader(
    dataset=train_data,             # 要加载的数据集
    num_workers=2,                  # 多线程工作
    shuffle=True,                   # 是否需要随机打乱数据集
    batch_size=BATCH_SIZE           # 每次加载的数量
)

# 与train_data相同
test_data = torchvision.datasets.MNIST(
    train=False,
    root='../mnist'
)

# 选取测试数据集中的一部分数据  test_data.test_data->图片数据  test_data.test_labels->标签数据
test_x = test_data.test_data.type(torch.FloatTensor)[:2000] / 255.
# print('test_x',test_x.shape)        # test_x torch.Size([2000, 28, 28])
test_y = test_data.test_labels[:2000]


# rnn
class RNN(torch.nn.Module):
    def __init__(self):
        super(RNN, self).__init__()             # 继承自torch.nn.Module里的__init__()
        self.rnn = torch.nn.LSTM(
            input_size=INPUT_SIZE,              # 输入数据大小,即每次输入一行的28个像素点
            hidden_size=64,                     # 64个隐藏神经元
            num_layers=1,                       # 有几层RNN,层数越多耗时越长
            batch_first=True                    # 是否将batch_size作为第一个维度
        )
        self.out = torch.nn.Linear(64, 10)

    def forward(self, x):
        r_out, (h_n, h_c) = self.rnn(x, None)   # r_out里边包含了batch_size,time_step,output_size
        out = self.out(r_out[:, -1, :])         # 输出最后一步的隐藏状态
        return out


# 创造实例
rnn = RNN()

# 优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)

# 损失函数
loss_func = torch.nn.CrossEntropyLoss()

if __name__ == '__main__':
    # 训练
    for epoch in range(EPOCH):
        for step, (x, batch_y) in enumerate(train_loader):
            # print('x',x.shape)  #x torch.Size([64, 1, 28, 28])    #batch_size,channels,width,height
            # print('----------------------------------')
            batch_x = x.view(-1, 28, 28)
            # print('batch',batch_x.shape)    #batch torch.Size([64, 28, 28])
            output = rnn(batch_x)
            loss = loss_func(output, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 50 == 0:
                test_out = rnn(test_x)
                pred_y = torch.max(test_out, 1)[1].data.numpy()
                # accuracy = sum(pred_y == test_y.data.numpy()) / float(test_y.size(0))
                accuracy = float(sum((pred_y == test_y.data.numpy()).astype(int))) / float(test_y.size(0))
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

365JHWZGo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值