Seq2Seq
RNN network structure
LSTM network structure
History of machine translation
Applications of Seq2Seq
Problems with Seq2Seq
Attention mechanism
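Attention works in a pattern that focuses on one specific region of an image at “high resolution” while perceiving the surrounding regions at “low resolution”. Extensive experiments have shown that applying the attention mechanism to machine translation, summarization, reading comprehension, and similar problems produces significant gains.
In Seq2Seq, this means attending to the content of particular states in the input sequence.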
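The idea of attending to particular input states can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the notes do not say which scoring function is used, so the sketch picks plain dot-product scoring, and the function names, shapes, and random values are all made up for the example.

```python
import numpy as np

def softmax(x):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def dot_product_attention(decoder_state, encoder_states):
    """Return (context, weights) for a single decoder step.

    decoder_state:  shape (hidden,)
    encoder_states: shape (seq_len, hidden)
    """
    scores = encoder_states @ decoder_state   # one score per input position
    weights = softmax(scores)                 # normalise scores to a distribution
    context = weights @ encoder_states        # weighted sum of the encoder states
    return context, weights

rng = np.random.default_rng(0)
encoder_states = rng.normal(size=(5, 4))  # 5 input positions, hidden size 4 (made-up values)
decoder_state = rng.normal(size=4)
context, weights = dot_product_attention(decoder_state, encoder_states)
print(weights)  # attention weights over the 5 input positions
print(context)  # context vector passed on to the decoder
```

Input positions with a larger weight contribute more to the context vector, which is the sense in which the decoder "focuses" on part of the input.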
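Bucket mechanism
Normally every sentence has to be padded to a common length. With buckets, sentences are first grouped by length, and padding and computation then happen per group.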
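A minimal sketch of the grouping step, assuming buckets whose padded length is a multiple of a fixed width. The helper names `bucket_by_length` and `pad_bucket` and the `bucket_width` parameter are hypothetical and only serve the illustration.

```python
from collections import defaultdict

PAD = '<PAD>'  # padding token

def bucket_by_length(words, bucket_width=3):
    """Group words into buckets keyed by the padded length of that bucket."""
    buckets = defaultdict(list)
    for word in words:
        # Round the length up to the next multiple of bucket_width.
        padded_len = -(-len(word) // bucket_width) * bucket_width
        buckets[padded_len].append(word)
    return dict(buckets)

def pad_bucket(words, padded_len):
    """Pad every word (as a list of characters) to the bucket's length."""
    return [list(w) + [PAD] * (padded_len - len(w)) for w in words]

words = ['bsaqq', 'npy', 'lbwuj', 'bqv', 'kial', 'edxpjpg']
buckets = bucket_by_length(words)
for padded_len, group in sorted(buckets.items()):
    print(padded_len, pad_bucket(group, padded_len))
```

Without buckets, all six words would be padded to length 7, the longest word. With buckets, the three-letter words need no padding at all.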
Sorting the letters of an input word with Seq2Seq
Seq2Seq
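This notebook implements a basic Seq2Seq model: given a word (a sequence of letters) as input, the model returns a “word” with those letters sorted alphabetically.
A basic Seq2Seq model has three main parts:
• Encoder
• Hidden state vector (connects the Encoder and the Decoder)
• Decoder
Task:
Sort the letters alphabetically: hello --> ehllo
Check the TensorFlow version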
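The mapping the model has to learn is the same as Python's built-in sort. The helper below, whose name `sort_letters` is hypothetical, gives a reference answer for checking predictions. Whether the target file was generated this way is an assumption.

```python
def sort_letters(word):
    """Return the letters of `word` in alphabetical order."""
    return ''.join(sorted(word))

print(sort_letters('hello'))  # ehllo
```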
http://www.lfd.uci.edu/~gohlke/pythonlibs/#tensorflow
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow.python.layers.core import Dense
# Check TensorFlow Version
assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))
TensorFlow Version: 1.2.0
# Load the data
import numpy as np
import time
import tensorflow as tf
with open('data/letters_source.txt', 'r', encoding='utf-8') as f:
    source_data = f.read()
with open('data/letters_target.txt', 'r', encoding='utf-8') as f:
    target_data = f.read()
# Preview the data
source_data.split('\n')[:10]
['bsaqq',
'npy',
'lbwuj',
'bqv',
'kial',
'tddam',
'edxpjpg',
'nspv',
'huloz',
'kmclq']
target_data.split('\n')[:10]
['abqqs',
'npy',
'bjluw',
'bqv',
'aikl',
'addmt',
'degjppx',
'npsv',
'hlouz',
'cklmq']
# Data preprocessing
def extract_character_vocab(data):
    '''
    Build the mapping tables (id -> character and character -> id).
    '''
    special_words = ['<PAD>', '<UNK>', '<GO>', '<EOS>']
    set_words = list(set([character for line in data.split('\n') for character in line]))
    # The four special tokens must be added to the vocabulary as well
    int_to_vocab = {
        idx: word for idx, word in enumerate(special_words + set_words)}
    vocab_to_int = {
        word: idx for idx, word in int_to_vocab.items()}
    return int_to_vocab, vocab_to_int
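Because the function iterates over a `set`, the id assigned to each letter can differ between Python runs, so ids stored by one run may not match the vocabulary built by another. The sketch below is a hypothetical variant, not the notebook's own code. It sorts the letters so the mapping is reproducible, and adds illustrative `encode` and `decode` helpers for a round trip. All three function names are assumptions.

```python
SPECIAL_WORDS = ['<PAD>', '<UNK>', '<GO>', '<EOS>']

def build_sorted_vocab(data):
    """Hypothetical variant of extract_character_vocab with a fixed id order."""
    # sorted() makes the letter-to-id assignment identical on every run.
    letters = sorted(set(ch for line in data.split('\n') for ch in line))
    int_to_vocab = dict(enumerate(SPECIAL_WORDS + letters))
    vocab_to_int = {word: idx for idx, word in int_to_vocab.items()}
    return int_to_vocab, vocab_to_int

def encode(word, vocab_to_int, add_eos=False):
    """Map letters to ids, falling back to <UNK> for unseen letters."""
    ids = [vocab_to_int.get(ch, vocab_to_int['<UNK>']) for ch in word]
    return ids + [vocab_to_int['<EOS>']] if add_eos else ids

def decode(ids, int_to_vocab):
    """Map ids back to letters, dropping the special tokens."""
    return ''.join(int_to_vocab[i] for i in ids
                   if int_to_vocab[i] not in SPECIAL_WORDS)

int_to_letter, letter_to_int = build_sorted_vocab('bsaqq\nnpy')
ids = encode('abqqs', letter_to_int, add_eos=True)
print(ids)                         # [4, 5, 8, 8, 9, 3]
print(decode(ids, int_to_letter))  # abqqs
```

The special tokens always occupy ids 0 to 3, so `<EOS>` is 3 here, and the letters follow from id 4 onward.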
# Build the mapping tables
source_int_to_letter, source_letter_to_int = extract_character_vocab(source_data)
target_int_to_letter, target_letter_to_int = extract_character_vocab(target_data)
# Convert the letters to ids
source_int = [[source_letter_to_int.get(letter, source_letter_to_int['<UNK>'])
               for letter in line] for line in source_data.split('\n')]
target_int = [[target_letter_to_int.get(letter, target_letter_to_int['<UNK>'])
               for letter in line] + [target_letter_to_int['<EOS>']] for line in target_data.split('\n')]
# Inspect the converted result
source_int[:10]
[[17, 9, 12, 11, 11],
[16, 29, 26],
[13, 17, 15, 25, 8],
[17, 11, 4],
[18, 10, 12, 13],
[23, 7, 7, 12, 24],
[27, 7, 6, 29, 8, 29, 5],
[16, 9, 29, 4],
[28, 25, 13, 21, 20],
[18, 24, 22, 13, 11]]
target_int[:10]
[[12, 17, 11, 11, 9, 3],
[16, 29, 26, 3],
[17, 8, 13, 25, 15, 3],
[17, 11, 4, 3],
[12, 10, 18, 13, 3],
[12, 7, 7, 24, 23, 3],
[7, 27, 5, 8, 29, 29, 6, 3],
[16, 29, 9, 4, 3],
[