关于RNN的理解说明,以下视频阐述的非常清晰明了。
【循环神经网络】5分钟搞懂RNN,3D动画深入浅出_哔哩哔哩_bilibili
数据导入和训练过程类似于CNN,在此只介绍RNN分类器的搭建:
# 搭建RNN分类器
class RNNimc(nn.Module):
def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
"""
:param input_dim: 输入数据的维度(图片每行的数据像素点)
:param hidden_dim:RNN神经元个数
:param layer_dim:RNN的层数
:param output_dim:隐藏层输出的维度(分类的数量)
"""
super(RNNimc, self).__init__()
self.hidden_dim = hidden_dim # RNN神经元个数
self.layer_dim = layer_dim # RNN的层数
# RNN
self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
# 全连接层
self.fcl = nn.Linear(hidden_dim, output_dim)
def forward(self, x):
# x:(64, 28, 28), (batch_num, token_num, word vector length)
out, h_n = self.rnn(x, None) # out:(64, 28, 128) h_n:(1, 64, 128)
out = self.fcl(out[:, -1, :]) # out:(64, 10) 以最后时刻输出的结果为准
return out
# 定义参数并调用
input_dim = 28
hidden_dim = 128
layer_dim = 1
output_dim = 10
MyRNNimc = RNNimc(input_dim, hidden_dim, layer_dim, output_dim)
print(MyRNNimc)
该图给出了RNN基本参数的含义。对于图片的识别,一张图片可以类比一个句子,都是一个二维向量,输入相同。