使用fairseq从头开始训练一个中日神经机器翻译模型

本文参照使用fairseq从头开始训练一个中英神经机器翻译模型_宇日辰的博客-CSDN博客

最近使用fairseq进行中日神经机器翻译,奈何找了一圈都没找到中日的,参照大佬中英的,写一篇中日的,希望对后来人有所帮助,本人菜鸡,如有错误欢迎指正!

1、安装相关工具

相关工具的安装参考:使用fairseq从头开始训练一个中英神经机器翻译模型_宇日辰的博客-CSDN博客

数据集的问题,参考:Electronics | Free Full-Text | WCC-JC 2.0: A Web-Crawled and Manually Aligned Parallel Corpus for Japanese-Chinese Neural Machine Translation (mdpi.com)

文中有对应数据集

2、数据处理

2.1 整体架构

~
├── mosesdecoder
├── subword-nmt
├── fairseq
└── nmt
    ├── data
           ├── result          # 用于存放翻译结果
           └── data-bin        # 用于存放二进制文件
    ├── models                  # 用于保存过程中的model文件和checkpoint
           └── checkpoints     # 保存checkpoints
    ├── utils                   # 一些其他工具
        ├── cutdt.py            # 用于划分train,valid,test
        └── cut2.py             # 用于划分src,tgt
        
#cut2文件
import sys

'''
Usage:
python cut2.py fpath new_data_dir
'''

def cut2(fpath, nsrc='ja', ntgt='zh'):
    fp = open(fpath, encoding='utf-8')
    name = fpath.split('.txt')[0] + '.'
    src_fp = open(name + nsrc, 'w', encoding='utf-8')
    tgt_fp = open(name + ntgt, 'w', encoding='utf-8')
    for line in fp.readlines():
        src_line, tgt_line = line.replace('\n', '').split('\t')
        src_fp.write(src_line + '\n')
        tgt_fp.write(tgt_line + '\n')
    src_fp.close()
    tgt_fp.close()

if __name__ == '__main__':
    cut2(fpath=sys.argv[1], nsrc='ja', ntgt='zh')

#cutdt文件
import os
import random
import sys

def split_data(raw_data, new_data_dir, val_size=2000, test_size=2000):
    with open(raw_data, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    random.shuffle(lines)

    val_pairs = []
    test_pairs = []
    train_pairs = []
    for line in lines:
        if len(line.split('\t')[0]) > 10 and len(line.split('\t')[1].strip()) > 10:
            if len(val_pairs) < val_size:
                val_pairs.append(line)
            elif len(test_pairs) < test_size:
                test_pairs.append(line)
            else:
                train_pairs.append(line)
        else:
            train_pairs.append(line)

    with open(new_data_dir + 'validc.txt', 'w', encoding='utf-8') as file:
        for pair in val_pairs:
            file.write(pair)

    with open(new_data_dir + 'testc.txt', 'w', encoding='utf-8') as file:
        for pair in test_pairs:
            file.write(pair)

    with open(new_data_dir + 'trainc.txt', 'w', encoding='utf-8') as file:
        for pair in train_pairs:
            file.write(pair)


if __name__ == '__main__':

    # 调用函数
    split_data(raw_data = sys.argv[1],new_data_dir=sys.argv[2])

2.2 路径脚本

#!/bin/sh

src=ja
tgt=zh

data_dir=~/nmt/data
model_dir=~/nmt/models
utils=~/nmt/utils

 2.3 切分中日

#注意此处选取测试和验证各2000条,且字符长度大于10
python ${utils}/cutdt.py ${data_dir}/pair.txt ${data_dir}/

#将三个数据集按中日分开
python ${utils}/cut2.py ${data_dir}/trainc.txt 
python ${utils}/cut2.py ${data_dir}/testc.txt 
python ${utils}/cut2.py ${data_dir}/validc.txt 

2.4 字符级分词

#本文中采用字符集别分词,而不使用mecab和jieba
#平均来说字符级别分词效果更好
cat trainc.ja | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > train.ja
cat trainc.zh | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > train.zh
cat testc.ja | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > test.ja
cat testc.zh | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > test.zh
cat validc.ja | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > valid.ja
cat validc.zh | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > valid.zh

3、训练

3.1 生成二进制文件

fairseq-preprocess --source-lang ${src} --target-lang ${tgt} \
    --trainpref ${data_dir}/train --validpref ${data_dir}/valid --testpref ${data_dir}/test \
    --destdir ${data_dir}/data-bin

3.2 开始训练

CUDA_VISIBLE_DEVICES=0 nohup fairseq-train ${data_dir}/data-bin --arch transformer \
    --source-lang ${src} --target-lang ${tgt}  \
    --optimizer adam  --lr 5e-4 --adam-betas '(0.9, 0.98)' \
    --lr-scheduler inverse_sqrt --max-tokens 4096  --dropout 0.3 \
    --criterion label_smoothed_cross_entropy  --label-smoothing 0.1 \
    --max-epoch 30  --warmup-updates 4000 --warmup-init-lr '1e-07' \
    --keep-last-epochs 5 --num-workers 4 \
    --save-dir ${model_dir}/checkpoints --eval-bleu\
    > ${data_dir}/output.txt &

3.3 生成式解码

fairseq-generate ${data_dir}/data-bin \
    --path ${model_dir}/checkpoints/checkpoint_best.pt \
    --batch-size 128 --beam 8 > ${data_dir}/result/bestbeam8.txt

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值