基于Transformer的本地机器翻译应用

纪念小白完成的第一个机器学习项目,虽然绝大部分代码源自开源项目,只是做了一个数据集替换的工作

1.项目来源

源项目来自DevilExileSu/transformer: transformer,机器翻译,中文--英文 (github.com)

项目自带的数据集应该是爬虫获取的,翻译质量很低且都是政治领域话题。因此,为了完成作业,尝试更换数据集进行模型训练

2.数据集

iwslt2017,是一个用于机器翻译研究的公开数据集,特别是针对口语翻译任务。IWSLT是一个年度会议,每年都会发布针对该年度会议主题的数据集,用于促进口语翻译技术的发展和评估。

数据集下载代码如下:

from datasets import load_dataset

# 设置数据集保存路径
save_path = "/path/to/your/dataset"

dataset = load_dataset("iwslt2017", 'iwslt2017-zh-en')
dataset.save_to_disk(save_path)

3. 数据处理

为了继续使用原项目的dataloader类,对原项目中vocab部分做了一些修改:

import pickle
from tqdm import tqdm
from collections import Counter
from data.tokenize import Tokenizer
from datasets import load_from_disk

class Vocab(object):
    def __init__(self, min_freq=10):
        self.vocab = Counter()
        self.min_freq = min_freq
        self.word2id = None
        self.id2word = None
        self.vocab_size = None

    def load(self, vocab_path):
        with open(vocab_path, 'rb') as f:
            tmp = pickle.load(f)
        self.vocab = tmp['vocab']
        self.min_freq = tmp['min_freq']
        self.word2id = tmp['word2id']
        self.id2word = tmp['id2word']
        self.vocab_size = tmp['vocab_size']

    def save(self, vocab_path):
        vocab = {
            "vocab": self.vocab,
            "min_freq": self.min_freq,
            "word2id": self.word2id,
            "id2word": self.id2word,
            "vocab_size": self.vocab_size
        }
        with open(vocab_path, 'wb') as f:
            pickle.dump(vocab, f)
    

    # 将数据集中句子分割为词并写入文件
    def create(self, file_name, lang):
        nlp = Tokenizer(lang)
        lang = lang[:2]

        print("-----------loading-----------")
        data = load_from_disk(file_name)
        lines = [pairs[lang] for pairs in data['translation']]
        f = open(file_name + '.' + lang + '.token', 'w', encoding='utf-8')
        for line in tqdm(lines):
            token = nlp.tokenizer(line)
            self.vocab.update(token)
            l = '\t'.join(token)
            f.write(l + '\n')

        tmp = self.vocab.most_common()
        tokens = ['<pad>', '<sos>', '<eos>', '<unk>']
        tokens += [i[0] for i in tmp if i[1] > self.min_freq]
        self.word2id = {word: idx for idx, word in enumerate(tokens)}
        self.id2word = {idx: word for word, idx in self.word2id.items()}
        self.vocab_size = len(self.word2id)


if __name__ == "__main__":
    zh = 'zh_core_web_md'
    en = 'en_core_web_md'
    train = '../dataset/train'

    valid = '../dataset/valid'

    test = '../dataset/test'

    zh_vocab = Vocab()
    en_vocab = Vocab()

    zh_vocab.create(train, zh)
    print(zh_vocab.vocab_size)
    zh_vocab.create(valid, zh)
    print(zh_vocab.vocab_size)
    zh_vocab.create(test, zh)
    print(zh_vocab.vocab_size)

    en_vocab.create(train, en)
    print(en_vocab.vocab_size)
    en_vocab.create(valid, en)
    print(en_vocab.vocab_size)
    en_vocab.create(test, en)
    print(en_vocab.vocab_size)

    zh_vocab.save('../dataset/zh_vocab.pkl')
    en_vocab.save('../dataset/en_vocab.pkl')




4.其他

训练时发现,原项目的PositionEncoder中设置的一个句子max_seq_len=200,而iwslt数据集中最长句子为204个token,因此产生了错误,将max_seq_len设置大于204即可解决

5.训练效果

  • 12
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值