目录
Machine Learning HW5
一、任务
机器翻译：将英语翻译为中文。
二、数据集
三、结果
因已过提交截止时间，故没有相关分数。
四、代码解析
库的引入和初始化
import sys
import pdb
import pprint
import logging
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import numpy as np
import tqdm.auto as tqdm
from pathlib import Path
from argparse import Namespace
from fairseq import utils
import matplotlib.pyplot as plt
"""# Fix random seed"""
seed = 73

# Seed every RNG this script can draw from so runs are repeatable:
# Python's `random`, NumPy's global RNG, and PyTorch on the CPU plus
# (only when CUDA is available) the current and all visible GPUs.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Give up cuDNN's autotuned (possibly non-deterministic) kernel choice in
# exchange for deterministic algorithm selection.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
数据预处理
data_dir = './DATA/rawdata'
dataset_name = 'ted2020'

# Absolute directory for this dataset's raw files; created (with any
# missing parents) if it does not exist yet.
prefix = Path(data_dir).absolute() / dataset_name
prefix.mkdir(parents=True, exist_ok=True)

# Language pair: source is English, target is Chinese.
src_lang, tgt_lang = 'en', 'zh'

# Path stems (without a language suffix) of the train/dev split and the
# test split. Presumably completed as f'{stem}.{lang}' by callers such as
# clean_corpus, which opens f'{prefix}.{l1}' -- confirm at the call site.
data_prefix, test_prefix = (f'{prefix}/{split}.raw' for split in ('train_dev', 'test'))
"""## Preprocess files"""
import re
# Code-point translation table for full-width -> half-width conversion:
# U+3000 (ideographic space) maps to an ASCII space, and U+FF01..U+FF5E
# (full-width '!'..'~') shift down by 0xFEE0 onto ASCII '!'..'~'.
_FULL_TO_HALF = {12288: 32, **{code: code - 65248 for code in range(65281, 65375)}}


def strQ2B(ustring):
    """Full width -> half width.

    Converts the ideographic space (U+3000) and the full-width ASCII
    variants (U+FF01..U+FF5E) to their half-width ASCII counterparts; every
    other character is returned unchanged.

    Args:
        ustring: the text to convert. An iterable of strings is also
            accepted; its pieces are concatenated before conversion.

    Returns:
        The converted text as a single ``str``.
    """
    # reference:https://ithelp.ithome.com.tw/articles/10233122
    if not isinstance(ustring, str):
        # Keep accepting iterables of strings. Translation is per character,
        # so joining first gives the same result as converting each piece.
        ustring = ''.join(ustring)
    # One C-level pass, instead of a Python loop that rebuilds the string
    # one character at a time.
    return ustring.translate(_FULL_TO_HALF)
# Matches one innermost parenthesised span such as "(laughter)".
_PAREN_SPAN = re.compile(r"\([^()]*\)")
# Characters deleted (None) or remapped in Chinese text.
_ZH_CHAR_MAP = str.maketrans({' ': None, '—': None, '“': '"', '”': '"', '_': None})


def clean_s(s, lang):
    """Normalise one sentence of the parallel corpus for language ``lang``.

    'en': drop parenthesised spans, delete hyphens, and pad the characters
          . , ; ! ? ( ) " with spaces.
    'zh': convert full-width characters to half-width via strQ2B, drop
          parenthesised spans, delete spaces, '—' and '_', map curly double
          quotes to '"', and pad 。 , ; ! ? ( ) " ~ 「 」 with spaces.
    other: no language-specific cleaning.

    In every case, runs of whitespace then collapse to a single space and
    leading/trailing whitespace is removed. Parenthesised spans are removed
    in a single pass, so for nested parentheses only the innermost pair
    (with its text) is dropped.
    """
    if lang == 'en':
        s = _PAREN_SPAN.sub("", s)
        s = s.replace('-', '')
        s = re.sub('([.,;!?()\"])', r' \1 ', s)  # keep punctuation as separate tokens
    elif lang == 'zh':
        # Convert to half-width first so full-width parentheses are removed too.
        s = _PAREN_SPAN.sub("", strQ2B(s))
        s = s.translate(_ZH_CHAR_MAP)
        s = re.sub('([。,;!?()\"~「」])', r' \1 ', s)  # keep punctuation as separate tokens
    # Argument-less split() already discards leading/trailing whitespace.
    return ' '.join(s.split())
def len_s(s, lang):
    """Return the length of sentence ``s`` as used for length filtering.

    For 'zh' this is the raw character count, so any spaces in ``s`` (for
    example those clean_s inserts around punctuation) are counted too. For
    every other language it is the number of whitespace-separated tokens.
    """
    return len(s) if lang == 'zh' else len(s.split())
def clean_corpus(prefix, l1, l2, ratio=9, max_len=1000, min_len=1):
if Path(f'{prefix}.clean.{l1}').exists() and Path(f'{prefix}.clean.{l2}').exists():
print(f'{prefix}.clean.{l1} & {l2} exists. skipping clean.')
return
with open(f'{prefix}.{l1}', 'r') as l1_in_f:
with open(f'{prefix}.{l2}', 'r') as l2_in_f:
with open(f'{prefix}.clean.{l1}', 'w') as l1_out_f:
with open(f'{prefix}.clean.{l2}', 'w') as l2_out_f:
for s1 in l1_in_f:
s1 = s1.strip()
s2 = l2_in_f.readline().strip()
s1 = clean_s(s1, l1)
s2 = clean_s(s2, l2)
s1_len = len_s(s1, l1)
s2_len = len_s(s2, l2)
if min_len > 0: # remove short sentence
if s1_len < min_len or s2_len < min_len:
continue
if max_len > 0: # remove long sentence
if s1_len > max_len or s2_len > max_len:
continue
if ratio > 0: # remo