# RNN (recurrent neural network) for classification (time order: the image is read row by row, top to bottom)
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch.utils.data as Data
torch.manual_seed(1)
# Hyperparameters
# each image is 28*28: it is read in 28 steps, one row of 28 pixels per step
EPOCH=1
BATCH_SIZE=64
TIME_STEP=28   # RNN time steps = image height
INPUT_SIZE=28  # data points fed in at each time step = image width
LR=0.01
DOWNLOAD_MNIST=False
train_data=torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
# plt.imshow(train_data.data[0].numpy(),cmap='gray')
# plt.title(train_data.targets[0].numpy())
# plt.show()
test_data=torchvision.datasets.MNIST(root='./mnist/',train=False)  # no transform here; converted to float tensors manually below
# mini-batch loader: each batch is (BATCH_SIZE, 1 channel, 28*28) -> (batch, 1, 28, 28)
train_loader=Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
# take the first 2000 test samples for evaluation
test_x=torch.unsqueeze(test_data.data,dim=1).type(torch.FloatTensor)[:2000]/255  # shape from (2000, 28, 28) to (2000, 1, 28, 28), scaled to [0, 1]
test_y=test_data.targets[:2000]
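# A quick sanity check (an addition, not in the original post): one batch from
# train_loader is (BATCH_SIZE, 1, 28, 28); dropping the channel dim with view gives
# the (batch, time_step, input_size) layout that the LSTM below expects.
# x0, y0 = next(iter(train_loader))
# print(x0.shape)                   # torch.Size([64, 1, 28, 28])
# print(x0.view(-1, 28, 28).shape)  # torch.Size([64, 28, 28])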
class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn=nn.LSTM(
            input_size=28,     # pixels per row of the image
            hidden_size=64,
            num_layers=1,      # number of stacked RNN layers
            batch_first=True,  # input is laid out as (batch, time_step, input_size);
                               # set False if your data is (time_step, batch, input_size)
        )
        self.out=nn.Linear(64,10)
    def forward(self,x):  # x: (batch, time_step, input_size)
        # At each time step the LSTM processes one row of x and forms its own summary
        # of what it has seen; the next step receives not only the new input but also
        # the previous (h_n, h_c). It then produces the outputs r_out together with the
        # updated (h_n, h_c): h_n is the hidden state (the "branch line") and h_c is
        # the cell state (the "main line").
        # None: there is no hidden state before the first step
        # r_out holds the outputs of all 28 steps, from the first to the last
        r_out,(h_n,h_c)=self.rnn(x,None)  # r_out: (batch, time_step, hidden_size)
                                          # h_n, h_c: (n_layers, batch, hidden_size)
        # use only the last step's output for classification;
        # r_out[:,-1,:] has the same values as h_n for this single-layer LSTM
        out=self.out(r_out[:,-1,:])
        return out
rnn=RNN()
# print(rnn)
'''
RNN(
(rnn): LSTM(28, 64, batch_first=True)
(out): Linear(in_features=64, out_features=10, bias=True)
)
'''
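# A minimal shape check with a dummy input (an addition, not in the original post):
# it confirms the shapes noted in forward(), and that for this single-layer LSTM the
# last-step output r_out[:, -1, :] matches the final hidden state h_n[0].
# dummy = torch.zeros(3, 28, 28)                  # (batch, time_step, input_size)
# r_out, (h_n, h_c) = rnn.rnn(dummy, None)
# print(r_out.shape)                              # torch.Size([3, 28, 64])
# print(h_n.shape, h_c.shape)                     # both torch.Size([1, 3, 64])
# print(torch.allclose(r_out[:, -1, :], h_n[0]))  # True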
# training and testing
optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss()  # the class labels are plain integer indices, not one-hot vectors
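# A quick illustration with dummy tensors (an addition, not in the original post):
# CrossEntropyLoss expects raw logits of shape (batch, num_classes) and integer
# class labels of shape (batch,); no one-hot encoding is needed.
# logits = torch.randn(4, 10)
# labels = torch.tensor([3, 1, 7, 0])
# print(loss_func(logits, labels))  # a scalar tensor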
for epoch in range(EPOCH):
    for step,(x,b_y) in enumerate(train_loader):
        b_x=x.view(-1,28,28)  # reshape x to (batch, time_step, input_size)
        output=rnn(b_x)
        loss=loss_func(output,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step%50==0:
            x = test_x.view(-1,28,28)
            test_out = rnn(x)
            pred_y = torch.max(test_out, 1)[1].data.squeeze()  # torch.max(input, 1) returns (values, indices); [1] is the predicted class per row
            accuracy = (pred_y == test_y).numpy().sum() / test_y.size(0)
            print('step: {} | train loss: {} | test accuracy: {}'.format(step, loss.data, accuracy))
# check predictions on the first ten test images
test_out=rnn(test_x[:10].view(-1,28,28))
pred_y=torch.max(test_out,1)[1].data.squeeze()
print('prediction:',pred_y)
print('real value:',test_y[:10])
# Results:
step: 0 | train loss: 2.2883260250091553 | test accuracy: 0.1025
step: 50 | train loss: 1.1060410737991333 | test accuracy: 0.6095
step: 100 | train loss: 0.8917801380157471 | test accuracy: 0.739
step: 150 | train loss: 0.5513193607330322 | test accuracy: 0.811
step: 200 | train loss: 0.2461433708667755 | test accuracy: 0.883
step: 250 | train loss: 0.2749921679496765 | test accuracy: 0.854
step: 300 | train loss: 0.21603426337242126 | test accuracy: 0.896
step: 350 | train loss: 0.4433455169200897 | test accuracy: 0.9125
step: 400 | train loss: 0.23373939096927643 | test accuracy: 0.928
step: 450 | train loss: 0.15998145937919617 | test accuracy: 0.93
step: 500 | train loss: 0.06220763549208641 | test accuracy: 0.934
step: 550 | train loss: 0.11671080440282822 | test accuracy: 0.954
step: 600 | train loss: 0.08472592383623123 | test accuracy: 0.947
step: 650 | train loss: 0.1787828803062439 | test accuracy: 0.9495
step: 700 | train loss: 0.23303718864917755 | test accuracy: 0.9495
step: 750 | train loss: 0.04452371597290039 | test accuracy: 0.9475
step: 800 | train loss: 0.09064620733261108 | test accuracy: 0.958
step: 850 | train loss: 0.1400453895330429 | test accuracy: 0.956
step: 900 | train loss: 0.1921684443950653 | test accuracy: 0.9645
prediction: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
real value: tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
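One detail worth noting: the test-time forward passes inside the training loop still build autograd graphs. As a small sketch (an addition, not part of the original script; the helper name evaluate is made up for illustration), the evaluation could be wrapped in torch.no_grad() to save memory:

def evaluate(model, x, y):
    # forward pass without gradient tracking; cheaper at test time
    with torch.no_grad():
        out = model(x.view(-1, 28, 28))
        pred = torch.max(out, 1)[1]
        return (pred == y).sum().item() / y.size(0)

# usage: accuracy = evaluate(rnn, test_x, test_y)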
Notes:
1. LSTM (Long Short-Term Memory): the output at a given time step depends on the current input and on the output of the previous step.
2. An LSTM is, in effect, a tool for slowing down the decay of the network's memory over long sequences.
3. An LSTM RNN has two tracks, a main line and a branch line:
① The branch line is managed by three controllers: input control, output control, and forget control.
② Input controller: if the current input is important for the result, the input controller writes this branch-line information into the main line, weighted by its importance, for further analysis.
③ Forget controller: if the current branch-line information changes the picture, the forget controller forgets some of the earlier main-line information and replaces it, in proportion, with the new information.
④ Output controller: decides what to output based on the current main-line and branch-line information.
4. The update of the main-line information is therefore determined by the input and forget controllers; the gate equations below make this precise.
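For reference, the controllers above are the gates of the standard LSTM cell (textbook form, not taken from the original post); the main line is the cell state $c_t$ and the branch line is the hidden state $h_t$:

$$
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) && \text{(input controller)} \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) && \text{(forget controller)} \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) && \text{(output controller)} \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t && \text{(main-line update)} \\
h_t &= o_t \odot \tanh(c_t) && \text{(branch-line output)}
\end{aligned}
$$

The main-line update makes note 4 explicit: $c_t$ depends only on the forget gate $f_t$ (how much old memory to keep) and the input gate $i_t$ (how much new information to write in).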