纪念小白完成的第一个机器学习项目,虽然绝大部分代码源自开源项目,只是做了一个数据集替换的工作
1.项目来源
源项目来自DevilExileSu/transformer: transformer,机器翻译,中文--英文 (github.com)
项目自带的数据集应该是爬虫获取的,翻译质量很低且都是政治领域话题。因此,为了完成作业,尝试更换数据集进行模型训练
2.数据集
iwslt2017,是一个用于机器翻译研究的公开数据集,特别是针对口语翻译任务。IWSLT是一个年度会议,每年都会发布针对该年度会议主题的数据集,用于促进口语翻译技术的发展和评估。
数据集下载代码如下:
from datasets import load_dataset
# 设置数据集保存路径
save_path = "/path/to/your/dataset"
dataset = load_dataset("iwslt2017", 'iwslt2017-zh-en')
dataset.save_to_disk(save_path)
3. 数据处理
为了继续使用原项目的dataloader类,对原项目中vocab部分做了一些修改:
import pickle
from tqdm import tqdm
from collections import Counter
from data.tokenize import Tokenizer
from datasets import load_from_disk
class Vocab(object):
def __init__(self, min_freq=10):
self.vocab = Counter()
self.min_freq = min_freq
self.word2id = None
self.id2word = None
self.vocab_size = None
def load(self, vocab_path):
with open(vocab_path, 'rb') as f:
tmp = pickle.load(f)
self.vocab = tmp['vocab']
self.min_freq = tmp['min_freq']
self.word2id = tmp['word2id']
self.id2word = tmp['id2word']
self.vocab_size = tmp['vocab_size']
def save(self, vocab_path):
vocab = {
"vocab": self.vocab,
"min_freq": self.min_freq,
"word2id": self.word2id,
"id2word": self.id2word,
"vocab_size": self.vocab_size
}
with open(vocab_path, 'wb') as f:
pickle.dump(vocab, f)
# 将数据集中句子分割为词并写入文件
def create(self, file_name, lang):
nlp = Tokenizer(lang)
lang = lang[:2]
print("-----------loading-----------")
data = load_from_disk(file_name)
lines = [pairs[lang] for pairs in data['translation']]
f = open(file_name + '.' + lang + '.token', 'w', encoding='utf-8')
for line in tqdm(lines):
token = nlp.tokenizer(line)
self.vocab.update(token)
l = '\t'.join(token)
f.write(l + '\n')
tmp = self.vocab.most_common()
tokens = ['<pad>', '<sos>', '<eos>', '<unk>']
tokens += [i[0] for i in tmp if i[1] > self.min_freq]
self.word2id = {word: idx for idx, word in enumerate(tokens)}
self.id2word = {idx: word for word, idx in self.word2id.items()}
self.vocab_size = len(self.word2id)
if __name__ == "__main__":
zh = 'zh_core_web_md'
en = 'en_core_web_md'
train = '../dataset/train'
valid = '../dataset/valid'
test = '../dataset/test'
zh_vocab = Vocab()
en_vocab = Vocab()
zh_vocab.create(train, zh)
print(zh_vocab.vocab_size)
zh_vocab.create(valid, zh)
print(zh_vocab.vocab_size)
zh_vocab.create(test, zh)
print(zh_vocab.vocab_size)
en_vocab.create(train, en)
print(en_vocab.vocab_size)
en_vocab.create(valid, en)
print(en_vocab.vocab_size)
en_vocab.create(test, en)
print(en_vocab.vocab_size)
zh_vocab.save('../dataset/zh_vocab.pkl')
en_vocab.save('../dataset/en_vocab.pkl')
4.其他
训练时发现,原项目的PositionEncoder中设置的一个句子max_seq_len=200,而iwslt数据集中最长句子为204个token,因此产生了错误,将max_seq_len设置大于204即可解决