2021SC@SDUSC
main函数:
# Excerpt from main(): UDA-specific loss and data-iterator setup.
# Only UDA mode has unlabeled data, so only then is the relative-entropy
# (KL-divergence) consistency loss needed.
if cfg.uda_mode:
    # reduction='none' keeps per-element KL terms (KL divergence == relative entropy)
    unsup_criterion = nn.KLDivLoss(reduction='none')
# 'train' mode uses sup + unsup iterators; 'train_eval' additionally gets an eval iterator.
data_iter = [data.sup_data_iter(), data.unsup_data_iter()] if cfg.mode=='train' \
    else [data.sup_data_iter(), data.unsup_data_iter(), data.eval_data_iter()] # train_eval
uda.json,main主函数运行的参数之一
{
  "seed": 42,
  "lr": 2e-5,
  "warmup": 0.1,
  "do_lower_case": true,
  "mode": "train_eval",
  "uda_mode": true
}
load_data.py
def sup_data_iter(self):
    """Build and return a DataLoader over the supervised (labeled) dataset.

    The concrete Dataset class is taken from ``self.TaskDataset`` and fed
    the configured data directory, preprocessing flag/pipeline, sequence
    length, mode, and the 'sup' data-type tag.
    """
    dataset = self.TaskDataset(
        self.sup_data_dir,
        self.cfg.need_prepro,
        self.pipeline,
        self.cfg.max_seq_length,
        self.cfg.mode,
        'sup',
    )
    return DataLoader(dataset, batch_size=self.sup_batch_size, shuffle=self.shuffle)
self.TaskDataset返回IMDB
class IMDB定义如下:
class IMDB(CsvDataset):
    """IMDB binary-sentiment dataset read from a TSV file.

    Supervised rows: label in column 7, text in column 6.
    Unsupervised rows: two text variants in columns 1 and 2
    (per the original comment: ko / en back-translation pair); no labels.
    """

    labels = ('0', '1')

    def __init__(self, file, need_prepro, pipeline=None, max_len=128, mode='train', d_type='sup'):
        # Fix mutable-default-argument pitfall: a fresh empty list per call
        # instead of one list object shared across every instantiation.
        # Passing pipeline=None is equivalent to the old default of [].
        super().__init__(file, need_prepro,
                         pipeline if pipeline is not None else [],
                         max_len, mode, d_type)

    def get_sup(self, lines):
        """Yield (label, text_a, text_b) triples for supervised rows."""
        # islice(lines, 0, None) in the original was a no-op wrapper;
        # iterating the reader directly is equivalent.
        for line in lines:
            yield line[7], line[6], []  # label, text_a, None

    def get_unsup(self, lines):
        """Yield ((label, text, text_b), ...) pairs of original/augmented rows."""
        for line in lines:
            # Unlabeled data: label slot is always None.
            yield (None, line[1], []), (None, line[2], [])  # ko, en
IMDB类继承CsvDataset类
class CsvDataset(Dataset):
    """Base Dataset that reads a tab-separated file and tensorizes rows.

    Subclasses supply ``get_sup`` / ``get_unsup`` generators mapping a raw
    csv row to (label, text_a, text_b)-style tuples (see IMDB).
    NOTE(review): only the ``need_prepro`` branch is visible in this
    excerpt; the "already preprocessed" branch continues past it.
    """
    # Label vocabulary; overridden by subclasses (e.g. IMDB uses ('0', '1')).
    labels = None
    def __init__(self, file, need_prepro, pipeline, max_len, mode, d_type):
        Dataset.__init__(self)
        self.cnt = 0  # count of unsupervised (ori, aug) pairs processed
        # need preprocessing
        if need_prepro:
            with open(file, 'r', encoding='utf-8') as f:
                lines = csv.reader(f, delimiter='\t', quotechar='"')
                # supervised dataset
                if d_type == 'sup':
                    # if mode == 'eval':
                    #     sentences = []
                    data = []
                    for instance in self.get_sup(lines):
                        # if mode == 'eval':
                        #     sentences.append([instance[1]])
                        # Each pipeline stage transforms the instance in turn.
                        for proc in pipeline:
                            instance = proc(instance, d_type)
                        data.append(instance)
                    # Transpose per-example tuples into per-field long tensors.
                    self.tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*data)]
                    # if mode == 'eval':
                    #     self.tensors.append(sentences)
                # unsupervised dataset
                elif d_type == 'unsup':
                    data = {'ori':[], 'aug':[]}
                    for ori, aug in self.get_unsup(lines):
                        # Run original and augmented text through the same pipeline.
                        for proc in pipeline:
                            ori = proc(ori, d_type)
                            aug = proc(aug, d_type)
                        self.cnt += 1
                        # if self.cnt == 10:
                        #     break
                        data['ori'].append(ori) # drop label_id
                        data['aug'].append(aug) # drop label_id
                    ori_tensor = [torch.tensor(x, dtype=torch.long) for x in zip(*data['ori'])]
                    aug_tensor = [torch.tensor(x, dtype=torch.long) for x in zip(*data['aug'])]
                    # Original-text tensors first, then augmented-text tensors.
                    self.tensors = ori_tensor + aug_tensor
        # already preprocessed (else-branch continues beyond this excerpt)
file参数传入的为uda.json中保存的训练数据文本:
input_ids input_mask input_type_ids label_ids
[101, 1000, 1038, 1000, 1998, 1000, 1039, 1000, 5918, 1010, 2017, 2064, 6807, 1998, 5959, 3166, 1005, 1055, 7953, 1012, 1040, 1007, 2000, 3929, 5959, 1996, 3185, 2017, 2442, 2293, 2308, 2066, 11382, 12096, 2098, 24654, 3367, 1010, 2040, 2003, 2061, 3019, 1010, 4086, 1998, 27149, 1012, 1041, 1007, 2017, 2442, 19837, 5541, 1010, 1037, 2210, 11463, 2319, 9905, 10415, 2111, 2007, 2307, 1998, 9487, 12857, 2065, 2017, 3113, 2035, 2122, 5918, 2017, 1005, 2222, 2022, 3497, 2000, 3446, 2023, 3185, 2379, 2184, 2685, 1012, 1045, 2196, 4191, 2431, 1006, 999, 1007, 2004, 2172, 2004, 2013, 3666, 2023, 17743, 1012, 1998, 1045, 2130, 3266, 2000, 5390, 2096, 5870, 1999, 2070, 5312, 1006, 1045, 2467, 2131, 7591, 1010, 7188, 2204, 2477, 4148, 2000, 11382, 19020, 24654, 3367, 1007, 102] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0
csv.reader()函数:The returned object is an iterator. Each iteration returns a row of the CSV file
输出一个按行输出csv文件的迭代器。