attention(注意力机制)原理和pytorch demo

1 篇文章 0 订阅
1 篇文章 0 订阅

目录

说明

RNN的局限性

注意力机制原理

注意力机制实现

第一步:编码

第二步:第0次打分并解码

第三步:第1次打分并解码

Demo链接和结果分析

总结&改进


说明

demo源自吴恩达老师的课程,从tensorflow修改为pytorch,略有不同。

RNN的局限性

原始数据是一个字符串:friday august 17 2001,长度是21(包含空格),为了简便这里把每一个字符用一个onehot向量表示。于是数据转化为21个onehot向量。依次输入到一个RNN网络(可以是普通RNN、也可以是LSTM和GRU),最终得到一个向量(即RNN网络中的隐状态)。如果此时用这个向量作为整个字符串的编码信息直接去解码,很可能会丢失一些信息,尤其是输入更长的字符串时,更容易丢失信息。并且很难抽取距离较远的两个特征之间的关系。

注意力机制原理

我们的目标是把这个字符串翻译成2001-08-17。想象一下如果是人来进行这个翻译,那么我们会做出如下映射关系,箭头即表示人的注意力机制。神经网络的注意力机制就是在模仿人类。

注意力机制实现

第一步:编码

由于输入序列是不定长,为方便计算,将全部输入都补充到长为30,补充方式为末尾加特定字符,记为<pad>。即friday august 17 2001<pad><pad><pad><pad><pad><pad><pad><pad><pad>

然后把对应的30个onehot向量(对于其他任务,可以是不同的特征向量),依次输入到encoder网络(这里使用双向LSTM)中,每次计算得到的隐状态向量全都保存下来,一共是30个(这里LSTM的隐状态向量长度设为64,由于是双向LSTM,长度一共是128),作为初始特征,记作Feature_30x128,这里30表示时间序列长度。

第二步:第0次打分并解码

此时解码部分RNN网络的隐状态向量H初始为全零(这里向量长度是64),复制30份,然后和Feature_30x128拼起来得到Feature_30x192。然后输入到一个全连接网络,输出是30*1维矩阵,即长为30的向量,最后经过softmax,得到30个打分(softmax的目的是让30个打分之和为1)。

此时有30个长为128的初始特征,即Feature_30x128;以及30个打分,相乘后加起来,得到一个128维的打分后特征,此操作举例如下(为简便,例子中的特征维度不是30*128,是3*4,则分数有3个)。

{\color{Red} {\color{Red} }Feature\_30x128}=\begin{bmatrix} 0.1 & 1.1 &0.7 \\ 0.2 & 0.5 &1.4 \\ 0.4 & 0.3 &0.5 \\ 0.3 & 0.6 &0.2 \end{bmatrix},score=\begin{bmatrix} 0.1 & 0.7 & 0.2 \end{bmatrix}

相乘后如下。

{\color{Red} Feature\_30x128}=\begin{bmatrix} 0.01 & 0.77 &0.14 \\ 0.02 & 0.35 &0.28 \\ 0.04 & 0.21 &0.1 \\ 0.03 & 0.42 &0.04 \end{bmatrix}

然后沿着时间维度相加,得到

{\color{Blue} Feature\_128}=\begin{bmatrix} 0.92 \\ 0.65 \\ 0.35 \\ 0.49 \end{bmatrix}

Feature_128输入到解码部分RNN网络,只向前传播一次,得到新的输出隐状态H,然后在经过一层全连接,进行分类(即输出哪个字符)。

第三步:第1次打分并解码

往后的解码和第二步都一样,只不过H在不断变化,用以拼在Feature_30x128上,指导如何打分。

Demo链接和结果分析

代码链接:https://github.com/zcsxll/date_trans_with_attention

我们实际上一共解码10次(因为2001-08-17这种输出格式长度固定为10),每次都会的到一个长为30的打分,即一张10*30的热图,如下图。一共10行,每一行长是30,是解码对应字符时的打分结果。

从图中可知注意力机制的效果十分明显,当解码月份08时,august部分的打分较大。至于年份部分的打分并非一一对应,是因为训练数据集中,一旦出现零几年,就只有二零零几年。

总结&改进

编码和解码部分都采用了LSTM,其中包含C和H两个隐变量,都可以拼到Feature_30x128上进行打分计算。

输出是分类任务,但实通过实验,对于本实例,训练时采用MSE损失比采用交叉熵收敛得更快。

输入到网络中的特征不应该是单个字符,而应该是单词,例如august,应该作为一个特征向量进行操作,而不是6个。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值