注意力模型的一个实例代码的实现与分析

上一篇文章:关于《注意力模型--Attention注意力机制》的学习 是对注意力模型的理论知识进行学习,这一篇文章将结合,在github上找到的一份基于keras框架实现的可运行的注意模型代码:Attention_Network_With_Keras 进行分析,进一步理解Attention模型。


将jupyter的文件转换为.py文件,方便在Pycharm中运行调试,转换方法如下图:

6102062-b7b645a6792ef881.png
文件转换操作方法




先解决一个问题:

因为Python环境不同,而导致的一个问题,运行程序时会在下图中报一个错:TypeError: softmax() got an unexpected keyword argument 'axis'

6102062-0dedd2595df1f684.png
1,出问题的地方
6102062-7164731e30891bdb.png
2,具体的出错信息

跳转到我目前版本的tensorflow后端softmax()函数没有参数axis

6102062-119a6e89b2879e26.png
3,寻找问题的具体原因
6102062-de919dadcf1f5bc2.png
4,寻找解决问题的方法
6102062-bdc8b21f1d953b41.png
5,解决办法:将K.softmax() 换成 tf.nn.softmax()

问题解决完毕!!!!!




程序分析与理解:

1.获取模型的样子

想获取搭建模型的样子,先按照上图:文件转换操作方法。将文件转换成.py文件,然后在Pycharm中运行程序代码。再参考文章:《kears可视化模块keras.utils.visualize_util 的安装配置与错误解决办法》 按照这篇文章处理后,添加语句:plot_model(model,to_file='AttentionModel.png',show_shapes=True)

6102062-f7d9edb73d125bfc.png
打印模型图片的语句
6102062-7498213951f94ba7.png
模型图1: 已获得 github上Attention_Network_With_Keras 搭建模型的样子


2.对模型的思路进行理解分析

以["six hours and fifty five am","06:55"]实例为例进行模型分析:

问题定义

将人类语言描述的时间,记为X;将标准数字描述的时间,记为Y。即<X,Y>类型,符合Encoder-Decoder框架。  X=["six hours and fifty five am"],Y=["06:55"]    任务:将X通过模型转换成Y

数据处理

对数据进行处理。数据集中<X,Y>句对样例有1万个,数据集在Time Dataset.json文件中。X集合,可以由41个不同的字符构成,将这41个字符存为字典类型;Y集合,可以由11个不同的字符构成,将这11个字符存为字典。其实这两个字典数据就存在Time Vocabs.json文件中。

将X、Y数据处理成索引形式,每一个索引对应于一个one-hot向量。比如:

X="six hours and fifty five am"   len(X)=27  ,模型中设置了X数据中的最大长度为41,索引len(X)=27<41,得进行索引填充(padding)。

X=['s','i','x',' ','h','o','u','r','s',' ','a','n','d',' ','f','i','f','t','y',' ','f','i','v','e',' ','a','m']

去字典human_vocab查询其索引值,并填充到41的长度,于是:

索引X=[31 22 36 0 21 27 33 30 31 0 14 26 17 0 19 22 19 32 37 0 19 22 34 18 0 14 25 40 40 40 40 40 40 40 40 40 40 40 40 40 40]

同理,Y="06:55",len(Y)=5,Y数据集的长度都为5,所以不需要填充。

Y=['0','6',':','5,'5']

去字典machine_vocab查询其索引值,得到:

索引Y=[0 6 10 5 5]

然后再将X索引转换为Xoh(one-hot)形式,Xoh维度:(41x41);将Y索引转换为Yoh(one-hot)形式,Yoh维度:(5x11)。如下图:

6102062-af2aa818eb3521c0.png

模型搭建:

1)通过get_model()函数获得搭建的模型

6102062-456c94555fada628.png

2)get_model()函数理解

6102062-2d5a1e31b2a57ea8.png

3)attention_layer()获取注意力的实现

6102062-eb33f786e3f3009b.png

4)one_step_of_attention()每一步获取注意力的过程

6102062-7ae45b0cbbd38fd4.png

5) 模型训练

6102062-0a11b5d8a46541ca.png

6)对各个网络层进行命名后的模型图片,两个下划线'__'后面的表示这个层在哪个函数中。

6102062-36f668de983eec36.png

7)模型图片划分

6102062-33057bf68e1a01bb.png
展开阅读全文

没有更多推荐了,返回首页