'''
Author: 365JHWZGo
Description: 15.上一次rnn学习的加工
Date: 2021/10/30 19:45
FilePath: day11-2.py
'''
# 导包
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data as Data
torch.manual_seed(1)
# 超参数
BATCH_SIZE = 64 # 一遍中每一次只训练64张图片
EPOCH = 1 # 对2000张照片只训练1遍
TIME_STEP = 28 # 对一张图片进行逐行扫描,一共扫描28行
INPUT_SIZE = 28 # 逐行扫描28个像素点
LR = 0.01 # optimizer的学习效率为0.01
DOWNLOAD_MNIST = False # 是否要下载MNIST
# 判断MNIST数据集是否已经在我的目录里,即是否已经下载好
if not (os.path.exists('../mnist')) or not os.listdir('../mnist'):
DOWNLOAD_MNIST = True
# 从MNIST中获取train_data
train_data = torchvision.datasets.MNIST(
root='../mnist', # 在哪里寻找MNIST
download=DOWNLOAD_MNIST, # 是否要下载MNIST
train=True, # 是否为训练数据
transform=torchvision.transforms.ToTensor() # 将train_data做一些转化【将PIL(python image library)图片转化为某种格式】
# 在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示
)
# 建立train_data的数据加载器,方便分批训练
train_loader = Data.DataLoader(
dataset=train_data, # 要加载的数据集
num_workers=2, # 多线程工作
shuffle=True, # 是否需要随机打乱数据集
batch_size=BATCH_SIZE # 每次加载的数量
)
# 与train_data相同
test_data = torchvision.datasets.MNIST(
train=False,
root='../mnist'
)
# 选取测试数据集中的一部分数据 test_data.test_data->图片数据 test_data.test_labels->标签数据
test_x = test_data.test_data.type(torch.FloatTensor)[:2000] / 255.
# print('test_x',test_x.shape) # test_x torch.Size([2000, 28, 28])
test_y = test_data.test_labels[:2000]
# rnn
class RNN(torch.nn.Module):
def __init__(self):
super(RNN, self).__init__() # 继承自torch.nn.Module里的__init__()
self.rnn = torch.nn.LSTM(
input_size=INPUT_SIZE, # 输入数据大小,即每次输入一行的28个像素点
hidden_size=64, # 64个隐藏神经元
num_layers=1, # 有几层RNN,层数越多耗时越长
batch_first=True # 是否将batch_size作为第一个维度
)
self.out = torch.nn.Linear(64, 10)
def forward(self, x):
r_out, (h_n, h_c) = self.rnn(x, None) # r_out里边包含了batch_size,time_step,output_size
out = self.out(r_out[:, -1, :]) # 输出最后一步的隐藏状态
return out
# 创造实例
rnn = RNN()
# 优化器
optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
# 损失函数
loss_func = torch.nn.CrossEntropyLoss()
if __name__ == '__main__':
# 训练
for epoch in range(EPOCH):
for step, (x, batch_y) in enumerate(train_loader):
# print('x',x.shape) #x torch.Size([64, 1, 28, 28]) #batch_size,channels,width,height
# print('----------------------------------')
batch_x = x.view(-1, 28, 28)
# print('batch',batch_x.shape) #batch torch.Size([64, 28, 28])
output = rnn(batch_x)
loss = loss_func(output, batch_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 50 == 0:
test_out = rnn(test_x)
pred_y = torch.max(test_out, 1)[1].data.numpy()
# accuracy = sum(pred_y == test_y.data.numpy()) / float(test_y.size(0))
accuracy = float(sum((pred_y == test_y.data.numpy()).astype(int))) / float(test_y.size(0))
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
15.上一次rnn学习的加工【代码详解】
最新推荐文章于 2024-09-16 17:36:34 发布