tensorflow seq2seq模型 代码阅读分析

本文主要分析TensorFlow中的Seq2Seq模型,包括embedding_attention_seq2seq、embedding_attention_decoder和attention_decoder函数。在理解模型前,应熟悉output projection、attention计算公式及embedding原理。文中提到attention计算公式,并指出模型中cell的两次attention机制应用。建议先理解translate项目代码再进行阅读。
摘要由CSDN通过智能技术生成

如果刚开始入门该模型请阅读tf官方说明:Sequence-to-Sequence Models

模型应用于机器翻译的示例代码:github

如果还没有看懂tf的translate示例代码,请先理解透彻translate项目代码之后再阅读本文。

开始

开始阅读源码之前,应该对模型有基本的认识,了解模型的基本原理。我认为需要注意的几个关键点是:
1、output projection的作用
2、attention的计算公式
3、embedding的作用和原理

这里强调一下attention的计算公式:
这里写图片描述
另外,模型中一个cell的计算过程其实用到了两次attention机制,一次作用与cell的输入,一次作用与cell的输出。

embedding_attention_seq2seq函数

def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                init_embedding=None,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
  '''对encoder输入进行embedding,运行encoder部分,将encoder输出作为参数传给embedding_attention_decoer'''

  with variable_scope.variable_scope(
      scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
    dtype = scope.dtype
    # 自己添加的代码,增加指定输入embedding的功能
    # embedding initializer
    if init_embedding:
        initializer = tf.constant_initializer(init_embedding,dtype=dtype)
    else:
        initializer = None
    # 将对输入的embedding添加到cell中
    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size,
        initializer=initializer)
    # 运行encoder,得到输出和最终状态
    encoder_outputs, encoder_state = core_rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)

    # 这里对encoder_outputs 进行reshape,变为[batch_size,input_length,cell_size]大小
    # encoder_outputs将会在attention_decoder中用于attention的计算
    # First calculate a concatenation of encoder outputs to put attention on.
    top_states = [
        array_ops.reshape(e, [-1, 1, cell.output_size]) for e in encoder_outputs
    ]
    attention_states = array_ops.concat(top_states, 1)

    # 在cell中添加outputprojection
    # 这种output
  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值