1. Data transformation class (TranslateData)
import random
import re
import sys
import unicodedata

import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

sys.path.append('.')
from vocabField import VocabField
class TranslateData():
    def __init__(self, pad=0):
        self.pad = pad

    def collate_fn(self, batch):
        """Pad the samples in a batch to a common length and stack them."""
        src = list(map(lambda x: x['src'], batch))
        tgt = list(map(lambda x: x['tgt'], batch))
        src_len = list(map(lambda x: x['src_len'], batch))
        tgt_len = list(map(lambda x: x['tgt_len'], batch))
        # pad_sequence pads to (max_len, batch); transpose to batch-first (batch, max_len)
        src = torch.transpose(pad_sequence(src, padding_value=self.pad), 0, 1)
        tgt = torch.transpose(pad_sequence(tgt, padding_value=self.pad), 0, 1)
        src_len = torch.stack(src_len)  # (batch, 1)
        tgt_len = torch.stack(tgt_len)  # (batch, 1)
        return {'src': src, 'tgt': tgt, 'src_len': src_len, 'tgt_len': tgt_len}
    def translate_data(self, subs, obj):
        def unicodeToAscii(s):
            # strip accents: decompose characters, then drop combining marks ('Mn')
            return ''.join(
                c for c in unicodedata.normalize('NFD', s)
                if unicodedata.category(c) != 'Mn'
            )

        def normalizeString(s):
            s = unicodeToAscii(s.lower().strip())
            s = re.sub(r'([.!?])', r' \1', s)      # split punctuation off the word
            s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)  # replace anything else with a space
            return s

        src, tgt = subs
        src = normalizeString(src).split(' ')
        tgt = normalizeString(tgt).split(' ')
        # wrap the target with start-of-sentence / end-of-sentence tokens
        tgt = [obj.tgt_vocab.sos_token] + tgt + [obj.tgt_vocab.eos_token]
        # discard pairs that exceed the length limits
        if len(src) > obj.max_src_length or len(tgt) > obj.max_tgt_length:
            return None
        src_length, tgt_length = len(src), len(tgt)
        src_ids = [obj.src_vocab.word2idx[w] for w in src]
        tgt_ids = [obj.tgt_vocab.word2idx[w] for w in tgt]
        return {
            'src': torch.LongTensor(src_ids),
            'tgt': torch.LongTensor(tgt_ids),
            'src_len': torch.LongTensor([src_length]),
            'tgt_len': torch.LongTensor([tgt_length])}
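A quick sanity check of collate_fn on two hand-made samples (the token ids below are made up for illustration; they are not from a real vocabulary):

trans = TranslateData(pad=0)
batch = [
    {'src': torch.LongTensor([3, 4, 5]), 'tgt': torch.LongTensor([1, 6, 2]),
     'src_len': torch.LongTensor([3]), 'tgt_len': torch.LongTensor([3])},
    {'src': torch.LongTensor([3, 4]), 'tgt': torch.LongTensor([1, 7, 8, 2]),
     'src_len': torch.LongTensor([2]), 'tgt_len': torch.LongTensor([4])},
]
out = trans.collate_fn(batch)
print(out['src'])
# tensor([[3, 4, 5],
#         [3, 4, 0]])   <- the shorter sample is padded with pad id 0
print(out['src_len'].shape)  # torch.Size([2, 1])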
2. Dataset class (DialogDataset)
class DialogDataset(Dataset):
    def __init__(self, data_fp, transform_func, src_vocab, tgt_vocab,
                 max_src_length, max_tgt_length):
        self.datasets = []
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        loaded = 0
        data_monitor = 0
        with open(data_fp, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc='Load Data:'):
                subs = line.strip().split('\t')
                loaded += 1
                # every line must contain the same number of tab-separated fields
                if not data_monitor:
                    data_monitor = len(subs)
                else:
                    assert data_monitor == len(subs)
                item = transform_func(subs, self)
                if item:
                    self.datasets.append(item)
        print(f"{loaded} pairs loaded. {len(self.datasets)} are valid. "
              f"Rate {1.0 * len(self.datasets) / loaded:.4f}")

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, idx):
        return self.datasets[idx]
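For reference, DialogDataset expects a plain-text file with one example per line, source and target separated by a single tab; the assert above rejects lines whose field count differs from the first line. Hypothetical lines (the actual column order depends on how fra_eng.dev was prepared) would look like:

va !	go .
j ai perdu .	i lost .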
3. Test
train_path = '../../data/fra2eng/fra_eng.dev'
dev_path = '../../data/fra2eng/fra_eng.dev'
src_vocab_file = '../../data/fra2eng/src_vocab_file'
tgt_vocab_file = '../../data/fra2eng/tgt_vocab_file'
src_vocab_size = 40000
tgt_vocab_size = 40000
max_src_length = 50
max_tgt_length = 50
batch_size = 20
src_vocab_list = VocabField.load_vocab(src_vocab_file)
tgt_vocab_list = VocabField.load_vocab(tgt_vocab_file)
src_vocab = VocabField(src_vocab_list,vocab_size = src_vocab_size)
tgt_vocab = VocabField(tgt_vocab_list,vocab_size = tgt_vocab_size)
pad_id = tgt_vocab.word2idx[tgt_vocab.pad_token]
trans_data = TranslateData()
train_set = DialogDataset(
    train_path,
    trans_data.translate_data,
    src_vocab,
    tgt_vocab,
    max_src_length=max_src_length,
    max_tgt_length=max_tgt_length
)
trainloader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    collate_fn=trans_data.collate_fn
)
dev_set = DialogDataset(
    dev_path,
    trans_data.translate_data,
    src_vocab,
    tgt_vocab,
    max_src_length=max_src_length,
    max_tgt_length=max_tgt_length
)
# both loaders share the same collate_fn so batches have identical structure
dev_loader = DataLoader(
    dev_set,
    batch_size=15,
    shuffle=False,
    collate_fn=trans_data.collate_fn
)
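With both loaders built, a minimal smoke test is to pull one batch and check the tensor shapes (the second dimension of src/tgt depends on the longest sentence in that batch):

for batch in trainloader:
    print(batch['src'].shape)      # torch.Size([20, max_src_len_in_batch])
    print(batch['tgt'].shape)      # torch.Size([20, max_tgt_len_in_batch])
    print(batch['src_len'].shape)  # torch.Size([20, 1])
    break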