从前,我对RNN的概念仅停留在它可以处理序列数据,进行机器翻译、文本生成、股票预测这些任务,但具体是怎么进行的,数据如何输入,具体的网络结构是什么,都不是很清晰。
现在以LSTM为例,介绍我的理解过程。
参考知乎的一个回答,LSTM的三种不同架构图及映射关系如下所示。
上图图3中对应的6个输入实际上为1个长度为6的词向量,结合一个具体的例子来看,参考csdn的一篇博客 ,假设我们现在有这两句话:
sentences = ["i love you","she loves me"]
把这两句话当作1个batch,time_step=3(即一句话的长度,如果长度不同需补到相同长度)
>>> batch_size=2
time_step1: i she
time_step2: love loves
time_step3: you me
假设我们现在的词表就只有上述出现过的单词
word_list=['i','love','you','she','loves','me']
word_dict={'i':0,'love':1,'you':2,'she':3,'loves':4,'me':5}
将这些单词转化为one-hot的词向量,维度为6
time_step1实际上就是:
array([[1,0,0,0,0,0], # i
[0,0,0,1,0,0]]) # she
这就是batch_size为2,time_step1输送给网络的值。
现在结合上面的LSTM架构图,图3实际上就是对应batch_size只有1的情况下,某个time_step输入数据以及这一层LSTM的结构示意图。这样一来RNN的输入[batch_size, time_step, input_size]也很好理解了。