本文参照使用fairseq从头开始训练一个中英神经机器翻译模型_宇日辰的博客-CSDN博客
最近使用fairseq进行中日神经机器翻译,奈何找了一圈都没找到中日的,参照大佬中英的,写一篇中日的,希望对后来人有所帮助,本人菜鸡,如有错误欢迎指正!
1、安装相关工具
相关工具的安装参考:使用fairseq从头开始训练一个中英神经机器翻译模型_宇日辰的博客-CSDN博客
文中有对应数据集
2、数据处理
2.1 整体架构
~
├── mosesdecoder
├── subword-nmt
├── fairseq
└── nmt
├── data
├── result # 用于存放翻译结果
└── data-bin # 用于存放二进制文件
├── models # 用于保存过程中的model文件和checkpoint
└── checkpoints # 保存checkpoints
├── utils # 一些其他工具
├── cutdt.py # 用于划分train,valid,test
└── cut2.py # 用于划分src,tgt
#cut2文件
import sys
'''
Usage:
python cut2.py fpath new_data_dir
'''
def cut2(fpath, nsrc='ja', ntgt='zh'):
fp = open(fpath, encoding='utf-8')
name = fpath.split('.txt')[0] + '.'
src_fp = open(name + nsrc, 'w', encoding='utf-8')
tgt_fp = open(name + ntgt, 'w', encoding='utf-8')
for line in fp.readlines():
src_line, tgt_line = line.replace('\n', '').split('\t')
src_fp.write(src_line + '\n')
tgt_fp.write(tgt_line + '\n')
src_fp.close()
tgt_fp.close()
if __name__ == '__main__':
cut2(fpath=sys.argv[1], nsrc='ja', ntgt='zh')
#cutdt文件
import os
import random
import sys
def split_data(raw_data, new_data_dir, val_size=2000, test_size=2000):
with open(raw_data, 'r', encoding='utf-8') as file:
lines = file.readlines()
random.shuffle(lines)
val_pairs = []
test_pairs = []
train_pairs = []
for line in lines:
if len(line.split('\t')[0]) > 10 and len(line.split('\t')[1].strip()) > 10:
if len(val_pairs) < val_size:
val_pairs.append(line)
elif len(test_pairs) < test_size:
test_pairs.append(line)
else:
train_pairs.append(line)
else:
train_pairs.append(line)
with open(new_data_dir + 'validc.txt', 'w', encoding='utf-8') as file:
for pair in val_pairs:
file.write(pair)
with open(new_data_dir + 'testc.txt', 'w', encoding='utf-8') as file:
for pair in test_pairs:
file.write(pair)
with open(new_data_dir + 'trainc.txt', 'w', encoding='utf-8') as file:
for pair in train_pairs:
file.write(pair)
if __name__ == '__main__':
# 调用函数
split_data(raw_data = sys.argv[1],new_data_dir=sys.argv[2])
2.2 路径脚本
#!/bin/sh
src=ja
tgt=zh
data_dir=~/nmt/data
model_dir=~/nmt/models
utils=~/nmt/utils
2.3 切分中日
#注意此处选取测试和验证各2000条,且字符长度大于10
python ${utils}/cutdt.py ${data_dir}/pair.txt ${data_dir}/
#将三个数据集按中日分开
python ${utils}/cut2.py ${data_dir}/trainc.txt
python ${utils}/cut2.py ${data_dir}/testc.txt
python ${utils}/cut2.py ${data_dir}/validc.txt
2.4 字符级分词
#本文中采用字符集别分词,而不使用mecab和jieba
#平均来说字符级别分词效果更好
cat trainc.ja | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > train.ja
cat trainc.zh | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > train.zh
cat testc.ja | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > test.ja
cat testc.zh | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > test.zh
cat validc.ja | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > valid.ja
cat validc.zh | sed 's/ //g' | awk '{gsub(/./,"& "); print}' > valid.zh
3、训练
3.1 生成二进制文件
fairseq-preprocess --source-lang ${src} --target-lang ${tgt} \
--trainpref ${data_dir}/train --validpref ${data_dir}/valid --testpref ${data_dir}/test \
--destdir ${data_dir}/data-bin
3.2 开始训练
CUDA_VISIBLE_DEVICES=0 nohup fairseq-train ${data_dir}/data-bin --arch transformer \
--source-lang ${src} --target-lang ${tgt} \
--optimizer adam --lr 5e-4 --adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --max-tokens 4096 --dropout 0.3 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-epoch 30 --warmup-updates 4000 --warmup-init-lr '1e-07' \
--keep-last-epochs 5 --num-workers 4 \
--save-dir ${model_dir}/checkpoints --eval-bleu\
> ${data_dir}/output.txt &
3.3 生成式解码
fairseq-generate ${data_dir}/data-bin \
--path ${model_dir}/checkpoints/checkpoint_best.pt \
--batch-size 128 --beam 8 > ${data_dir}/result/bestbeam8.txt