机器如何读懂人心:Keras实现Self-Attention文本分类

640?wx_fmt=jpeg


作者 | 小宋是呢

转载自CSDN博客


一、Self-Attention概念详解


了解了模型大致原理,我们可以详细的看一下究竟Self-Attention结构是怎样的。其基本结构如下

640?wx_fmt=jpeg

对于self-attention来讲,Q(Query), K(Key), V(Value)三个矩阵均来自同一输入,首先我们要计算Q与K之间的点乘,然后为了防止其结果过大,会除以一个尺度标度 640?wx_fmt=svg ,其中 640?wx_fmt=svg 为一个query和key向量的维度。再利用Softmax操作将其结果归一化为概率分布,然后再乘以矩阵V就得到权重求和的表示。该操作可以表示为 


640?wx_fmt=svg


这里可能比较抽象,我们来看一个具体的例子(图片来源于https://jalammar.github.io/illustrated-transformer/),该博客讲解的极其清晰,强烈推荐),假如我们要翻译一个词组Thinking Machines,其中Thinking的输入的embedding vector用 640?wx_fmt=svg 表示,Machines的embedding vector用 640?wx_fmt=svg 表示。


640?wx_fmt=jpeg


当我们处理Thinking这个词时,我们需要计算句子中所有词与它的Attention Score,这就像将当前词作为搜索的query,去和句子中所有词(包含该词本身)的key去匹配,看看相关度有多高。我们用 640?wx_fmt=svg 代表Thinking对应的query vector, 640?wx_fmt=svg 及 640?wx_fmt=svg 分别代表Thinking以及Machines对应的key vector,则计算Thinking的attention score的时候我们需要计算 640?wx_fmt=svg 与 640?wx_fmt=svg 的点乘,同理,我们计算Machines的attention score的时候需要计算640?wx_fmt=svg 与 640?wx_fmt=svg 的点乘。如上图中所示我们分别得到了640?wx_fmt=svg 与 640?wx_fmt=svg 的点乘积,然后我们进行尺度缩放与softmax归一化,如下图所示:


640?wx_fmt=jpeg


显然,当前单词与其自身的attention score一般最大,其他单词根据与当前单词重要程度有相应的score。然后我们在用这些attention score与value vector相乘,得到加权的向量。


640?wx_fmt=jpeg


如果将输入的所有向量合并为矩阵形式,则所有query, key, value向量也可以合并为矩阵形式表示:


640?wx_fmt=jpeg


其中 640?wx_fmt=svg 是我们模型训练过程学习到的合适的参数。上述操作即可简化为矩阵形式:


640?wx_fmt=jpeg

 二、Self_Attention模型搭建

 

笔者使用Keras来实现对于Self_Attention模型的搭建,由于网络中间参数量比较多,这里采用自定义网络层的方法构建Self_Attention。

Keras实现自定义网络层。需要实现以下三个方法:(注意input_shape是包含batch_size项的

  • build(input_shape): 这是你定义权重的地方。这个方法必须设 self.built = True,可以通过调用 super([Layer], self).build() 完成。

  • call(x): 这里是编写层的功能逻辑的地方。你只需要关注传入 call 的第一个参数:输入张量,除非你希望你的层支持masking。

  • compute_output_shape(input_shape): 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。

实现代码如下:

 
 

这里可以对照一中的概念讲解来理解代码


如果将输入的所有向量合并为矩阵形式,则所有query, key, value向量也可以合并为矩阵形式表示


640?wx_fmt=jpeg

上述内容对应

 
 

其中 640?wx_fmt=svg 是我们模型训练过程学习到的合适的参数。上述操作即可简化为矩阵形式:

640?wx_fmt=jpeg

上述内容对应(为什么使用batch_dot呢?这是由于input_shape是包含batch_size项的

 
 

这里 QK = QK / (64**0.5) 是除以一个归一化系数,(64**0.5)是笔者自己定义的,其他文章可能会采用不同的方法。 


三、训练网络


项目完整代码如下,这里使用的是Keras自带的imdb影评数据集。

 
 

 

四、结果输出 

640?wx_fmt=png

参考链接:

https://zhuanlan.zhihu.com/p/47282410

原文链接

https://blog.csdn.net/xiaosongshine/article/details/90600028


(*本文为 AI科技大本营转载文章,转载请联系原作者)



精彩推荐



大会开幕倒计时8天!


2019以太坊技术及应用大会特邀以太坊创始人V神与众多海内外知名技术专家齐聚北京,聚焦区块链技术,把握时代机遇,深耕行业应用,共话以太坊2.0新生态。即刻扫码,享优惠票价。


640?wx_fmt=png


640?wx_fmt=png


推荐阅读



640?wx_fmt=png 你点的每个“在看”,我都认真当成了喜欢
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值