对seq2seq的一些个人理解

因为做毕设用到seq2seq框架,网上关于seq2seq的资料很多,但关于seq2seq的代码则比较少,阅读tensorflow的源码则需要跳来跳去比较麻烦(其实就是博主懒)。踩了很多坑后,形成了一些个人的理解,在这里记录下,如果有人恰好路过,欢迎指出错误~

seq2seq图解如下:
C是状态
上图中,C是encoder输出的最终状态,作为decoder的初始状态;W是encoder的最终输出,作为decoder的初始输入。

具体到tensorflow代码中(tensorflow r1.1.0cpu版本),查阅tf.contrib.rnn.BasicLSTMCell的源码如下:

class BasicLSTMCell(RNNCell):

  def __init__(self, num_units, forget_bias=1.0,input_size=None, state_is_tuple=True, activation=tanh,reuse=None):

      super(BasicLSTMCell, self).__init__(_reuse=reuse)
      if not state_is_tuple:
          logging.warn("%s: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.", self)
      if input_size is not None:
          logging.warn("%s: The input_size parameter is deprecated.", self)
      self._num_units = num_units
      self._forget_bias = forget_bias
      self._state_is_tuple = state_is_tuple
      self._activation = activation

  @property
  def state_size(self):
      return (LSTMStateTuple(self._num_units, self._num_units) if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
      return self._num_units

  def call(self, inputs, state):
      """Long short-term memory cell (LSTM)."""
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
          c, h = state
      else:
          c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

      concat = _linear([inputs, h], 4 * self._num_units, True)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
          new_state = LSTMStateTuple(new_c, new_h)
      else:
          new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state

令调用LSTM的命令为:

output,state = tf.contrib.rnn.BasicLSTMCell(input,init_state)

可知,state其实是包含了output在内的。state[0]才是真正的state,即图中的C;state[1]是output,即图中的W。这样一来,最后输出的output其实就显得鸡肋了。(如果要在encode和decode之间搞事情的话,这点就比较重要了。博主就是踩了这个坑。。。当然如果不在这里搞事情的话就可以完美绕过这个坑)

知道这点后,那么接下来的就好理解多了。博主之前曾有过一段时间的疑惑,那就是seq2seq的decode_input到底是什么?如果跟target只是移了一个位,其他完全不变的话,那要encoder干什么?知道了上面的背景后,我们不难知道,教程中decode_input跟target的移位只是加速训练过程。而在具体应用中,decode_input可以是encode的最后一个输出,也可以自己设定一个全零的数组。个人觉得设定全零的数组比较好,因为初始状态就已经包含了encode的最后一个输出了,而且全零数组可以当作是一个开始的标识(至于seq2seq具体的训练过程可视化,可以阅读2017年ACL的一篇文章Visualizing and Understanding Neural Machine Translation http://nlp.csai.tsinghua.edu.cn/~ly/papers/acl2017_dyz.pdf

最后,还说几点比较零散的:
1、对于短句(<30词),可以不进行输入翻转,模型收敛地稍微慢一点而已;对于长句则最好进行翻转
2、多阅读教程,多实践。上手操作永远是学习的最佳途径

微信扫码订阅
UP更新不错过~
关注
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Zsank

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值