Pointer-network理论及tensorflow实战

数据下载地址:链接:https://pan.baidu.com/s/1nwJiu4T 密码:6joq
本文代码地址:https://github.com/princewen/tensorflow_practice/tree/master/myPtrNetwork

1、什么是pointer-network

Pointer Networks 是发表在机器学习顶级会议NIPS 2015上的一篇文章,其作者分别来自Google Brain和UC Berkeley。

Pointer Networks 也是一种seq2seq模型。他在attention mechanism的基础上做了改进,克服了seq2seq模型中“输出严重依赖输入”的问题。

什么是“输出严重依赖输入”呢?

论文里举了个例子,给定一些二维空间中[0,1]*[1,0]范围内的点,求这些点的凸包(convex hull)。凸包是凸优化里的重要概念,含义如下图所示,通俗来讲,即找到几个点能把所有点“包”起来。比如,模型的输入是序列{P1,P2,...,P7},输出序列是凸包{P2,P4,P3,P5,P6,P7,P2}。到这里,“输出严重依赖输入”的意思也就明了了,即输出{P2,P4,P3,P5,P6,P7,P2}是从输入序列{P1,P2,...,P7}中提取出来的。换个输入,如{P1,....,P1000},那么输出序列就是从{P1,....,P1000}里面选出来。用论文中的语言来描述,即{P1,P2,...,P7}和{P1,....,P1000}的凸包,输出分别依赖于输入的长度,两个问题求解的target class不一样,一个是7,另一个是1000。

Pointer Network在求凸包上的效果如何呢?

从Accuracy一栏可以看到,Ptr-net明显优于LSTM和LSTM+Attention。

为啥叫pointer network呢?

前面说到,对于凸包的求解,就是从输入序列{P1,....,P1000}中选点的过程。选点的方法就叫pointer,他不像attetion mechanism将输入信息通过encoder整合成context vector,而是将attention转化为一个pointer,来选择原来输入序列中的元素。

与attention的区别:如果你也了解attention的原理,可以看看pointer是如何修改attention的?如果不了解,这一部分就可以跳过了。

首先搬出attention mechanism的公式,前两个公式是整合encoder和decoder的隐式状态,学出来encoder、decoder隐式状态与当前输出的权重关系a,然后根据权重关系a和隐式状态e得到context vector用来预测下一个输出。

Pointer Net没有最后一个公式,即将权重关系a和隐式状态整合为context vector,而是直接进行通过softmax,指向输入序列选择中最有可能是输出的元素。

如果你对上面的理论还没有理解的很到位,那么我们通过代码来进一步讲解,相信你通过这段代码,可以对Ptr的理论有一个更深入的认识。

2、pointer-network实现

这段代码源自:https://github.com/devsisters/pointer-network-tensorflow
上面的代码 实现比较复杂,连下载数据的过程都有,真的是十分费劲,我直接把数据下载好了,上传到百度云上了,大家可以自行下载(地址见文章开头)。

代码目录如下:

config.py 定义了模型的配置data_util.py 定义了数据处理过程main.py 模型的主入口,定义了模型的训练过程model.py 定义了我们的pointer-network模型

我们这里主要讲解我们的数据处理和模型定义两个文件

2.1 数据处理

好了,我们来看看我们的数据吧:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值