Dataloader
DataLoader(dataset, sampler=None, collate_fn=None,
batch_size=1, shuffle=False, num_workers=0,
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
这里先从dataset的运行机制讲起.
在dataloader按照batch进行取数据的时候, 是取出大小等同于batch size的index列表; 然后将列表列表中的index输入到dataset的getitem()函数中,取出该index对应的数据; 最后, 对每个index对应的数据进行堆叠, 就形成了一个batch的数据。
在最后一步堆叠的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠. 如: multi-hot类型的数据, 序列数据.在使用这些数据时, 通常需要先进行长度上的补齐, 再进行堆叠. 以现在的流程, 是没有办法加入该操作的.
⭐️ 此外, 某些优化方法是要对一个batch的数据进行操作.
collate-fn函数就是手动将抽取出的样本堆叠起来的函数
所以就是:
先sampler获取要采样数据的一组index列表 ——> dataset的getitem函数用这组index获取一组数据,需要什么内容都在这里用index从准备好的数据中获取——> 然后用Collect_fn 对得到的这组数据进行一些处理,比如补齐,删减,批处理等,获得最终要用的数据及格式。
如何给collate-fn赋值?
info = args.info # info是已经定义过的
loader = Dataloader(collate_fn=lambda x: collate_fn(x, info))
或者可以定义一个collactor类,实例化后传给collate_fn
在这个collator类中,需要有三个关键方法
- def init(self, max_seq_len: int, tokenizer: BertTokenizer):
- def call(self, examples: list) -> dict:
- def pad_and_truncate(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):
其中初始化方法 __init__就是传递一些基本信息
然后__call__方法会从dataset中得到一batch的数据(examples),数据内容是dataset的getitem()方法的返回结果,可以对这个结果进行处理
input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))
将这些结果传入到定义的数据处理方法pad_and_truncate()中去,对这些结果进行补齐等操作。
该方法返回的最终结果就是训练过程中dataloader取一batch的数据。
class AFQMCDataset(Dataset):
def __init__(self, data_dict: dict):
super(AFQMCDataset, self).__init__()
self.data_dict = data_dict
def __getitem__(self, index: int) -> tuple:
data = (self.data_dict['input_ids'][index], self.data_dict['token_type_ids'][index], self.data_dict['attention_mask'][index], self.data_dict['labels'][index])
return data
def __len__(self) -> int:
return len(self.data_dict['input_ids'])
class Collator:
def __init__(self, max_seq_len: int, tokenizer: BertTokenizer):
self.max_seq_len = max_seq_len
self.tokenizer = tokenizer
def pad_and_truncate(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):
input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.zeros_like(input_ids)
for i in range(len(input_ids_list)):
seq_len = len(input_ids_list[i])
if seq_len <= max_seq_len:
input_ids[i, :seq_len] = torch.tensor(input_ids_list[i], dtype=torch.long)
token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i], dtype=torch.long)
attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i], dtype=torch.long)
else:
input_ids[i] = torch.tensor(input_ids_list[i][:max_seq_len - 1] + [self.tokenizer.sep_token_id], dtype=torch.long)
token_type_ids[i] = torch.tensor(token_type_ids_list[i][:max_seq_len], dtype=torch.long)
attention_mask[i] = torch.tensor(attention_mask_list[i][:max_seq_len], dtype=torch.long)
labels = torch.tensor(labels_list, dtype=torch.long)
return input_ids, token_type_ids, attention_mask, labels
def __call__(self, examples: list) -> dict:
input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))
cur_max_seq_len = max(len(input_id) for input_id in input_ids_list)
max_seq_len = min(cur_max_seq_len, self.max_seq_len)
input_ids, token_type_ids, attention_mask, labels = self.pad_and_truncate(input_ids_list, token_type_ids_list,
attention_mask_list, labels_list,
max_seq_len)
data_dict = {
'input_ids': input_ids,
'token_type_ids': token_type_ids,
'attention_mask': attention_mask,
'labels': labels
}
return data_dict
```
动画参考:https://mp.weixin.qq.com/s/Uc2LYM6tIOY8KyxB7aQrOw