本例开始使用NLP中的经典结构-RNN。首先构建了ElmanRnn(其实pytorch自带这个网络,作者为了读者能更清楚的了解RNN的结构,所以额外构造了它),在此基础上搭建分类器。
待解决的问题:根据人名识别所属国家
目录
网络结构和数据维度
Column gather 函数
1. 这个函数很有意思,配合RNN网络使用,目的是为了取出序列输出的最后一个词向量,要注意的是单词长度不同,长度较短的单词向量都在尾部用mask填充,所以利用参数x_length保证不会取到mask。
2. 为什么只取最后一个词向量呢?因为经过RNN的序列处理,单词的所有contextual信息都被记忆在了最后一个词向量,比如john这个名字,j,o,h的信息都被或多或少地保存在n的向量中,所以只取出n的向量再传入FC层进行分类就足够了。
3. 注意第12行,使用了detach(),到目前为止,只知道会将该节点从图中分离,形成叶节点,从而造成梯度不回传。(但似乎不适用这个方法也没什么影响)有待进一步调研
def column_gather(y_out, x_length):
"""Get a specific vector fr