首先第一点RNN提出的初衷是有记忆力。
以下面这个例子来说,如果不考虑Taipei前面的单词(arrive, leave)则对神经网络来说两句话中的taipei是一个意思,神经网络没法区分taipei是目的地还是出发地。所以就产生了RNN,让他可以记住Taipei的上下文。
这种有记忆力的NN就叫做RNN(Recurrent Neural Network)。
李宏毅深度学习课程:李宏毅2020机器学习深度学习(完整版)国语_哔哩哔哩 (゜-゜)つロ 干杯~-bilibili
获得记忆力的具体过程如下:
每次都将隐藏层的输出,存储到内存中来。
下一次,当有Input的时候,神经元就不单单考虑Input中的x1和x2,神经元还会考虑存储到内存中的a1和a2。也就是将x1,x2和a1,a2拼接起来,一起作为输入。
在实际训练中,我们一般会给内存中的值隐层输出一个初始值,一般来说将其全部初始化为0。
RNN还有一个特性,如果调整输入的顺序的话,RNN的输出也会不同。即changing the sequence order will change the output。
如果将输入调整为【2,2】,【1,1】,【1,1】
则输出序列会变为:【8, 8】,【20, 20】,【44,44】
所以在RNN中,RNN会考虑输入的顺序,order。
RNN的计算过程,展开之后如下:
分别将arrive taipei 和 on作为input时(不同时间点)的网络状况。同一个NN在不同的时间点被使用了三次。
我们在用pytorch官方的一个RNN的例子来深入理解上述原理。
在这个例子中,RNN网络的定义和训练过程源码分别如下:
RNN网络的定义,这个网络很简单只有一个隐藏层和一个输出层。
import torch.nn as nn
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined)
output = self.i2o(combined)
output = self.softmax(output)
return output, hidden
def initHidden(self):
return torch.zeros(1, self.hidden_size)
n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)
从上面的forward函数中可以看出,
RNN的第一个操作就是将输入input和内存中的隐层输出hidden 给cat起来。
RNN的第二个操作是计算新的隐层的输出,并保存到hidden中。
网络的训练代码如下:
learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
def train(category_tensor, line_tensor):
hidden = rnn.initHidden()
rnn.zero_grad() #梯度归0
for i in range(line_tensor.size()[0]):
output, hidden = rnn(line_tensor[i], hidden)
loss = criterion(output, category_tensor) #loss计算
loss.backward() #反向传播
# Add parameters' gradients to their values, multiplied by learning rate
for p in rnn.parameters():
p.data.add_(p.grad.data, alpha=-learning_rate) #梯度更新
return output, loss.item()
训练过程的第一步就是调用initHidden函数来初始化内存中保存隐层输出的hidden变量。
关于完整源码可以参考:
https://github.com/spro/practical-pytorch/tree/master/char-rnn-classificationgithub.com https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.htmlpytorch.org