目录
声明
文中出现的代码有些是部分的、表意用的,不能直接搬用,一个是直接上完整代码篇幅就太长了,另一个是我本人希望用“授人以渔”态度写文章。有不足的地方还请各位指正!
数据集
开始写这个项目时,我找的数据集来源于 https://wit3.fbk.eu/2015-01,是TED Talk的中英双语数据集。但是当时不知道效果为啥不好,就换了个数据集,是2017年的AI Challenge中用于翻译比赛的数据。下载地址:数据集,部分内容如下:
在这个数据集上的效果还行,从数据上看,可能是因为TED Talk的数据是人类自然用语,更复杂多变、有很多长句;但是AI Challenge的数据绝大部分是如上图的简单短句。因此建议读者在实践时选择简单的数据集、限制用于训练的最大句子长度。
数据处理
-
句子水平的处理
- 去除异种符号,包括但不限于数字、“#@”等奇怪的符号、其他语言字符,我在实践中还去除了标点符号,这点我也拿不准,仁者见仁智者见智吧;
sentence = re.sub(r"[a-zA-Z]", r"", sentence) if IsChinese(sentence) else sentence sentence = re.sub(r"[0-9]", r"", sentence) sentence = re.sub(r"[@#$%^-&*]", r"", sentence) sentence = re.sub(r"[,.!?';:,。!?’‘“”;:]", r"", sentence)
- 分词,中文我使用python库jieba,英文我使用python库nltk;
# word segmentation sentence = jieba.lcut(sentence) if IsChinese(sentence) else nltk.word_tokenize(sentence)
- 特殊处理,面对具体的语言和数据集会有不同的处理,还比如英文可以做小写化,比如在我用的数据集中文语句里还有空格隔开以及括号括住但是英文语句里没有的部分,这些我就去掉了;
# lowercase sentence = sentence.lower() # delete invalid characters sentence = re.sub(r"(.*)", r"", sentence) if IsChinese(sentence) else sentence sentence = re.sub(r" ", r"", sentence) if IsChinese(sentence) else sentence
- 加上开始和结束标志符,这个随意,可以用“bos”表示开始、“eos”表示结束,读者可以自行定义。我实际的处理是给源语言(source language)加开始符号、给目标语言(target language)加开始和结束符号;
for sent_index in range(len(source_data)): # This loop is for dropping long sentence and building vocabulary. if len(source_data[sent_index]) + 1 <= src_max_length and \ len(target_data[sent_index]) + 2 <= tgt_max_length: source.AddSentence(source_data[sent_index]) target.AddSentence(target_data[sent_index]) data.append((source_data[sent_index] + ["<eos>"], ["<bos>"] + target_data[sent_index] + ["<eos>"])) # Actually just add "<eos>" to source data and "<bos>" and "<eos>" both to target data.
- 补足句子,把短句子补长,这样才能放在一个batch里输入到模型,一般使用0填补且0不能作为词汇表中某个词的索引代表。有两种方式:一种全部补长到最大句子长度;一种把同一个batch里的句子补长到这个batch里的最大长度,暂且称之为“batch padding”。第二种可以节省存储、算力、时间,建议第二种。
def DoBatchPadding(self, data): dataset = [] for batch_index in range(len(data) // self.batch_size): # This loop decides how many batches in trian data. src_batch = [] tgt_batch = [] for sent_index in range(batch_index * self.batch_size,
- 去除异种符号,包括但不限于数字、“#@”等奇怪的符号、其他语言字符,我在实践中还去除了标点符号,这点我也拿不准,仁者见仁智者见智吧;