用tensorflow实现seq2seq模型

本文详细介绍了如何使用TensorFlow实现Seq2Seq模型,包括数据预处理、模型创建、批次处理和检查点保存。通过电影剧本对话数据集进行预处理,使用结巴分词,移除低频词,并添加开始、结束和填充标签。模型采用bucketing处理不同长度的句子,调整网络参数如单元数目、层数和学习率。在训练过程中,利用tf的计算图和梯度更新进行模型优化。
摘要由CSDN通过智能技术生成

训练数据和预处理

数据集是电影剧本中的对话,我们首先需要做一些预处理以获得正确的数据格式。

  1. 切字分词
    使用结巴分词。
  2. 移除低频词
    代码中,用vocabulary_size 限制词表的大小。用UNK代替不包括在词表中的单词。例如,单词“非线性”不在词表中,则句子“非线性在神经网络中是重要的”变成“UNK在神经网络中是重要的”。
  3. 准备开始和结束标签以及填充标签
    在decoder端,GO表示解码开始,用EOS表示解码结束,同时用PAD表示填充。模型使用bucketing处理不同长度的句子。如果输入是3个tocken的英语句子,相应的输出是6个tocken的法语句子,则它们将被放入到[5,10]的bucket中。编码器将输入的长度将填充到5,解码器输入的长度将填充到10 ,填充标签是PAD。

创建模型

model = seq2seq_model.Seq2SeqModel(
      source_vocab_size=FLAGS.vocab_size,
      target_vocab_size=FLAGS.vocab_size,
      buckets=BUCKETS,
      size=FLAGS.size,
      num_layers=FLAGS.num_layers,
      max_gradient_norm=FLAGS.max_gradient_norm,
      batch_size=FLAGS.batch_size,
      learning_rate=FLAGS.learning_rate,
      learning_rate_decay_factor=FLAGS.learning_rate_decay_factor,
      use_lstm=False,
      forward_only=forward_only)
  • source_vocab_size是源输入词表的大小,
  • target_vocab_size是目标输出词表的大小。
  • Bucketing是一种有效处理不同长度的句子的方法。例如将英语翻译成法语时,输入具有不同长度的英语句子L1,输出是具有不同长度的法语句子L2,原则上应该为每一对(L1,L2 + 1)创建一个seq2seq模型。这会导致图很大,包括许多非常相似的子图。另一方面,我们可以用一个特殊的_PAD符号填充每个句子。然后,只需要一个seq2seq模型。但是对于较短的句子,要编码和解码许多无用的PAD符号,这样的模型也是低效的。作为折中,使用多个buckets 并且将每个句子填充为对应的bucket的长度。在config.py中,使用以下bucket。
BUCKETS = [(8, 10), (10, 15), (20, 25), (40, 50)]
  • 参数size代表网络中每一层的单元数目。
  • num_layers 代表网络的层数。
  • max_gradient_norm 表示梯度将被最大限度地削减到这个规范
  • batch_size 表示训练时的批处理大小。模型的构建与batch_size大小无关,所以在初始化之后它仍可以改变,e.g., 在decoding是时候可以改变batch_size
  • learning_rate 是初始的学习率。
  • learning_rate_decay_factor 学习率衰减因子,到了一定的阶段,学习率按照衰减因子进行衰减。
  • use_lstm 是一个布尔变量,表示是否使用lstm作为基本单元。true表示使用lstm,false表示使用gru。
  • num_samples 代表采样sampled softmax的数量。因为我们是使用softmax来处理输出的,如果输出词表很大时,计算效率会受到影响。因此当输出输出词表大
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值