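1. Overview
This article implements Chinese-to-English translation with an attention-based seq2seq model. The GRU, which is at the core of the paper, serves as the basic unit of the seq2seq model. Sample translation results are shown in the figure below: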
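2. Model structure
The overall model structure is shown in the figure below:
The relevant formulas are shown in the figure below:
Notation:
- FC = fully connected (dense) layer
- EO = encoder output
- H = hidden state
- X = decoder input
The formulas translate into the following computation:
- score = FC(tanh(FC(EO) + FC(H)))
- attention weights = softmax(score, axis = 1)
- context vector = sum(attention weights * EO, axis = 1)
- embedding output = the decoder input X passed through an embedding layer
- merged vector = concat(embedding output, context vector)
The merged vector is then fed to the decoder's GRU as its input.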
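The computation above can be sketched in NumPy as follows. This is only an illustration under stated assumptions, not the model's actual implementation: the matrices `W1`, `W2` and `v` stand in for the three FC layers (biases omitted) and are randomly initialised, and the sizes (batch of 2, source length 5, 8 hidden units, 4-dimensional embeddings, 10 attention units) are chosen purely for the example.

```python
import numpy as np

def softmax(x, axis):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_step(EO, H, emb, W1, W2, v):
    """One decoder step of the attention computation listed above.

    EO:  encoder output,        shape (batch, src_len, hidden)
    H:   hidden state,          shape (batch, hidden)
    emb: embedding output of X, shape (batch, embed_dim)
    """
    # Add a time axis to H so it broadcasts over the source positions.
    H_time = H[:, np.newaxis, :]                   # (batch, 1, hidden)
    # score = FC(tanh(FC(EO) + FC(H)))
    score = np.tanh(EO @ W1 + H_time @ W2) @ v     # (batch, src_len, 1)
    # attention weights = softmax(score, axis = 1)
    weights = softmax(score, axis=1)
    # context vector = sum(attention weights * EO, axis = 1)
    context = (weights * EO).sum(axis=1)           # (batch, hidden)
    # merged vector = concat(embedding output, context vector)
    merged = np.concatenate([emb, context], axis=-1)
    return weights, context, merged

# Assumed toy sizes and random stand-in weights, for illustration only.
rng = np.random.default_rng(0)
batch, src_len, hidden, embed_dim, units = 2, 5, 8, 4, 10
EO = rng.normal(size=(batch, src_len, hidden))
H = rng.normal(size=(batch, hidden))
emb = rng.normal(size=(batch, embed_dim))
W1 = rng.normal(size=(hidden, units))
W2 = rng.normal(size=(hidden, units))
v = rng.normal(size=(units, 1))

weights, context, merged = attention_step(EO, H, emb, W1, W2, v)
print(weights.shape, context.shape, merged.shape)  # (2, 5, 1) (2, 8) (2, 12)
```

The softmax runs over axis 1, the source positions, so each sentence's weights sum to 1 and the context vector is a weighted average of the encoder outputs.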
3. Implementation
3.1 Import the required libraries
import tensorflow as tf
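# Needed on TensorFlow 1.x only; TensorFlow 2.x runs eagerly by default
# and no longer provides tf.enable_eager_execution
tf.enable_eager_execution()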
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import re
import numpy as np
import os
import time
import jieba
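3.2 Corpus processing
Set the corpus path as in the following code:
# Path to the corpus file
path_to_file = "cmn1.txt"
# Number of examples to use. The dataset contains 20294 entries in total; 20000 of them are used as training data
num_examples = 20000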
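For orientation, the corpus-processing code further down splits each line of the file on a tab and treats the first field as English and the second as Chinese. A minimal sketch of that parsing, assuming this layout; the `parse_pairs` helper and the sample lines are hypothetical and not taken from `cmn1.txt`:

```python
def parse_pairs(text, num_examples):
    # One sentence pair per line, with the two fields separated by a tab.
    lines = text.strip().split('\n')
    # Keep only the first num_examples lines, as the tutorial does.
    return [line.split('\t') for line in lines[:num_examples]]

# Made-up sample in the assumed "English<TAB>Chinese" layout.
sample = "Hi.\t嗨。\nI won!\t我赢了。\nRun.\t跑。"
pairs = parse_pairs(sample, num_examples=2)
for en, zh in pairs:
    print(en, '->', zh)
```

Slicing with `num_examples` simply drops everything after the first N lines, so the examples left out are the last ones in the file.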
The corpus-processing functions are shown in the following code:
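# Preprocess a single sentence
def preprocess_sentence(w):
    # Segment the sentence with jieba and join the tokens with spaces
    w = ' '.join(jieba.cut(w))
    w = '<start> ' + w + ' <end>'
    # Collapse runs of spaces into a single space (the character class also
    # matches double quotes, so those are replaced by a space as well)
    w = re.sub(r'[" "]+', " ", w)
    # Strip leading and trailing whitespace
    w = w.strip()
    return w

# Preprocess every sentence and return [English, Chinese] pairs
def create_dataset(path, num_examples):
    # Read the whole file and close it when done
    with open(path, encoding='UTF-8') as f:
        lines = f.read().strip().split('\n')
    word_pairs = [[preprocess_sentence(w) for w in l.split('\t')] for l in lines[:num_examples]]
    return word_pairs

def max_length(tensor):
    return max(len(t) for t in tensor)

# Builds the word -> index and index -> word mappings for one language
class LanguageIndex():
    def __init__(self, lang):
        self.lang = lang
        self.word2idx = {}
        self.idx2word = {}
        self.vocab = set()
        self.create_index()

    def create_index(self):
        # Collect every space-separated token from every phrase
        for phrase in self.lang:
            self.vocab.update(phrase.split(' '))
        self.vocab = sorted(self.vocab)
        # Index 0 is reserved for padding; real tokens start at 1
        self.word2idx['<pad>'] = 0
        for index, word in enumerate(self.vocab):
            self.word2idx[word] = index + 1
        for word, index in self.word2idx.items():
            self.idx2word[index] = word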
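To see what the index built by `LanguageIndex` looks like, here is a condensed, function-style sketch of the same steps applied to two already-tokenized toy sentences. The helper name `build_index` and the sample sentences are hypothetical; the logic is meant to mirror `create_index` above.

```python
def build_index(phrases):
    # Collect every space-separated token, then sort for a stable ordering.
    vocab = sorted({word for phrase in phrases for word in phrase.split(' ')})
    # Index 0 is reserved for padding; real tokens start at 1.
    word2idx = {'<pad>': 0}
    for index, word in enumerate(vocab):
        word2idx[word] = index + 1
    # Invert the mapping to decode indices back into tokens.
    idx2word = {index: word for word, index in word2idx.items()}
    return word2idx, idx2word

phrases = ['<start> i won <end>', '<start> i run <end>']
word2idx, idx2word = build_index(phrases)
# Encode each sentence as a list of token indices.
encoded = [[word2idx[w] for w in p.split(' ')] for p in phrases]
print(word2idx)
print(encoded)  # [[2, 3, 5, 1], [2, 3, 4, 1]]
```

Because the vocabulary is sorted before indexing, the same corpus always produces the same indices. The markers `<end>` and `<start>` get indices 1 and 2 here because `<` sorts before lowercase letters.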
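The functions for loading the dataset are shown in the following code:
# Define the dataset-loading function load_dataset
def load_dataset(path, num_examples):
    # Clean the data and create the input/output pairs
    pairs = create_dataset(path, num_examples)
    # Each pair is [English, Chinese]; `sp` holds the Chinese (input) sentence
    inp_lang = LanguageIndex(sp for en, sp in pairs)
    targ_lang = LanguageIndex(en for en, sp in pairs)
    input_tensor = [[inp_lang.word2idx[s] for s in sp.split(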