seq2seq

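This post walks through building a Seq2Seq model with TensorFlow: data preprocessing, model structure (encoder and decoder), hyperparameter settings, training, and prediction. It processes the decoder input, builds the training and inference decoders, and finally connects the encoder and decoder to complete the model.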

Dataset

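  • letters_source.txt: the list of input letter sequences, one sequence per line
  • letters_target.txt: the list of target letter sequences used for training; each line corresponds to the same line in the input list
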
import numpy as np
import time

import helper

source_path = 'data/letters_source.txt'
target_path = 'data/letters_target.txt'

source_sentences = helper.load_data(source_path)
target_sentences = helper.load_data(target_path)
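
The `helper` module is not shown in this post. Since the later code calls `.split('\n')` on the result, `helper.load_data` presumably returns the whole file as a single string. A minimal stand-in under that assumption might look like the sketch below; the function body, the UTF-8 encoding, and the demo file name are assumptions, not the actual helper.

```python
import os
import tempfile


def load_data(path):
    """Hypothetical stand-in for helper.load_data: return the file as one string."""
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()


# Demo on a throwaway file holding two sequences, one per line.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, 'letters_demo.txt')
    with open(demo_path, 'w', encoding='utf-8') as f:
        f.write('bsaqq\nnpy')
    demo_text = load_data(demo_path)

# Splitting on newlines recovers one sequence per line.
demo_lines = demo_text.split('\n')
```

Keeping the data as one string and splitting later matches how the preprocessing code consumes it.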

Preprocess

Next, convert the letters to integers.

def extract_character_vocab(data):
    # Special tokens: padding, unknown character, decoder start, end of sequence
    special_words = ['<PAD>', '<UNK>', '<GO>', '<EOS>']

    # Every distinct character that appears in the data
    set_words = set([character for line in data.split('\n') for character in line])
    int_to_vocab = {word_i: word for word_i, word in enumerate(special_words + list(set_words))}
    vocab_to_int = {word: word_i for word_i, word in int_to_vocab.items()}

    return int_to_vocab, vocab_to_int

# Build two dictionaries for each side: int-to-letter and letter-to-int
source_int_to_letter, source_letter_to_int = extract_character_vocab(source_sentences)
target_int_to_letter, target_letter_to_int = extract_character_vocab(target_sentences)
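
To make the mapping concrete, here is a small worked example on toy data. It is a sketch rather than the post's exact function: the name `build_vocab` and the toy strings are made up for illustration, and the characters are sorted before numbering. Sorting is a deliberate change, because Python randomizes string hashes per process by default, so the iteration order of a set of characters (and therefore the ids) can differ between runs.

```python
def build_vocab(data):
    """Illustrative variant of the vocabulary builder with a stable id order."""
    special_words = ['<PAD>', '<UNK>', '<GO>', '<EOS>']
    # sorted() fixes the order so ids are reproducible (assumption: not in the original)
    characters = sorted({ch for line in data.split('\n') for ch in line})
    int_to_vocab = dict(enumerate(special_words + characters))
    vocab_to_int = {word: idx for idx, word in int_to_vocab.items()}
    return int_to_vocab, vocab_to_int


toy_data = 'bca\ncab'
toy_int_to_letter, toy_letter_to_int = build_vocab(toy_data)

# Encode each line; characters missing from the vocabulary map to '<UNK>'.
toy_ids = [[toy_letter_to_int.get(ch, toy_letter_to_int['<UNK>']) for ch in line]
           for line in toy_data.split('\n')]
unknown_id = toy_letter_to_int.get('z', toy_letter_to_int['<UNK>'])
```

The four special tokens take ids 0 to 3, so the letters a, b, c receive ids 4, 5, 6 and `toy_ids` becomes `[[5, 6, 4], [6, 4, 5]]`.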

# Convert the letters to integers and append the '<EOS>' marker to each target sequence.
# get(letter, source_letter_to_int['<UNK>']) maps any letter missing from the vocabulary to '<UNK>'.
source_letter_ids = [[source_letter_to_int.get(letter, source_letter_to_int['<UNK>']) 
                      for letter in line] 
                     for line in source_sentences.split('\n')]
target_letter_ids = [[target_letter_to_int.get(letter, target_letter_to_int['<UNK>']) 
                      for letter in line] + [target_letter_to_int['<EOS>']] 
                     for line in target_sentences.split('\n')] 

print("Example source sequence")
print(source_letter_ids[:10])
print("\n")
print("Example target sequence")
print(target_letter_ids[:10])
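
The `<PAD>` token is defined above but not yet used. Sequences in a batch usually have different lengths, so before they can be fed as one rectangular array the shorter ones are typically padded. A minimal sketch of that step follows; the function name `pad_batch` and the toy ids are assumptions, and the pad id is 0 only because `<PAD>` comes first in `special_words`.

```python
def pad_batch(batch, pad_int):
    """Right-pad every sequence in the batch to the length of the longest one."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_int] * (max_len - len(seq)) for seq in batch]


# Toy batch of three id sequences with lengths 3, 2 and 1.
toy_batch = [[5, 6, 4], [6, 4], [4]]
padded_batch = pad_batch(toy_batch, pad_int=0)
```

Here `padded_batch` is `[[5, 6, 4], [6, 4, 0], [4, 0, 0]]`. The function builds new lists, so the original sequences keep their true lengths.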

Model

Check the TensorFlow version

from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow.python.layers.core import Dense


assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))
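
Note that `distutils` was deprecated in Python 3.10 and removed in Python 3.12, so the import above fails on newer interpreters. One possible workaround, using only the standard library, is to compare the leading numeric parts of the version string. The helper names `version_tuple` and `at_least` below are made up for this sketch, and it ignores pre-release suffixes such as `rc1`.

```python
import re


def version_tuple(version):
    """Return the leading numeric components of a version string as a tuple of ints."""
    parts = []
    for piece in version.split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def at_least(version, minimum):
    """True if `version` is greater than or equal to `minimum`, compared numerically."""
    got, want = version_tuple(version), version_tuple(minimum)
    # Pad the shorter tuple with zeros so '1.1' and '1.1.0' compare as equal.
    width = max(len(got), len(want))
    got += (0,) * (width - len(got))
    want += (0,) * (width - len(want))
    return got >= want
```

With these helpers the check could be written as `assert at_least(tf.__version__, '1.1')`. Comparing integers rather than strings matters here: as strings, `'1.10'` would sort before `'1.9'`.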

Hyperparameters

epochs = 60
batch_size = 128
# RNN Size
rnn_size = 50

num_layers = 2
# Embedding Size
encoding_embedding_size = 15
decoding_embedding_size = 15

learning_rate = 0.001

Input

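def get_model_inputs():
    # Token ids for the source and target batches, shaped [batch_size, sequence_length]
    input_data = tf.placeholder(tf.int32, [None, None], name='input')
    targets = tf.placeholder(tf.int32, [None, None], name='targets')
    lr = tf.placeholder(tf.float32, name='learning_rate')

    # Length of each sequence in the batch, plus the longest target length
    target_sequence_length = tf.placeholder(tf.int32, (None,),
                                            name='target_sequence_length')
    max_target_sequence_length = tf.reduce_max(target_sequence_length,
                                               name='max_target_len')
    source_sequence_length = tf.placeholder(tf.int32, (None,),
                                            name='source_sequence_length')

    # Return the length tensors as well, so they are not discarded
    return (input_data, targets, lr, target_sequence_length,
            max_target_sequence_length, source_sequence_length)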
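
To show what these placeholders expect, the sketch below prepares matching values for a toy batch in plain Python. The ids and the dictionary `feed_values` are illustrative assumptions: its keys simply mirror the placeholder `name` arguments, whereas in a real session the keys of `feed_dict` would be the placeholder tensors themselves.

```python
PAD = 0  # id of '<PAD>', first entry of special_words

# Already padded toy batches (3 is the id of '<EOS>' in the targets).
input_batch = [[5, 6, 4], [6, 4, PAD]]
target_batch = [[4, 5, 6, 3], [4, 6, 3, PAD]]


def sequence_lengths(batch):
    """Count the non-padding tokens in each row of a padded batch."""
    return [sum(1 for token in row if token != PAD) for row in batch]


feed_values = {
    'input': input_batch,                                     # shape [2, 3]
    'targets': target_batch,                                  # shape [2, 4]
    'learning_rate': 0.001,
    'source_sequence_length': sequence_lengths(input_batch),
    'target_sequence_length': sequence_lengths(target_batch),
}

# Plain-Python counterpart of tf.reduce_max(target_sequence_length).
max_target_len = max(feed_values['target_sequence_length'])
```

Only the two length vectors need to be fed. The maximum target length is computed inside the graph from `target_sequence_length`, which is why it is not a placeholder.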