dataloader中 sampler、collate_fn 和 dataset 的 getitem使用理解。

Dataloader

DataLoader(dataset, sampler=None, collate_fn=None,
batch_size=1, shuffle=False, num_workers=0,
pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

这里先从dataset的运行机制讲起.

在dataloader按照batch进行取数据的时候, 是取出大小等同于batch size的index列表; 然后将列表列表中的index输入到dataset的getitem()函数中,取出该index对应的数据; 最后, 对每个index对应的数据进行堆叠, 就形成了一个batch的数据。

在最后一步堆叠的时候可能会出现问题: 如果一条数据中所含有的每个数据元的长度不同, 那么将无法进行堆叠. 如: multi-hot类型的数据, 序列数据.在使用这些数据时, 通常需要先进行长度上的补齐, 再进行堆叠. 以现在的流程, 是没有办法加入该操作的.
⭐️ 此外, 某些优化方法是要对一个batch的数据进行操作.

collate-fn函数就是手动将抽取出的样本堆叠起来的函数

所以就是:
先sampler获取要采样数据的一组index列表 ——> dataset的getitem函数用这组index获取一组数据,需要什么内容都在这里用index从准备好的数据中获取——> 然后用Collect_fn 对得到的这组数据进行一些处理,比如补齐,删减,批处理等,获得最终要用的数据及格式。

如何给collate-fn赋值?

info = args.info	# info是已经定义过的
loader = Dataloader(collate_fn=lambda x: collate_fn(x, info))

或者可以定义一个collactor类,实例化后传给collate_fn
在这个collator类中,需要有三个关键方法

  • def init(self, max_seq_len: int, tokenizer: BertTokenizer):
  • def call(self, examples: list) -> dict:
  • def pad_and_truncate(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):

其中初始化方法 __init__就是传递一些基本信息
然后__call__方法会从dataset中得到一batch的数据(examples),数据内容是dataset的getitem()方法的返回结果,可以对这个结果进行处理

        input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))

将这些结果传入到定义的数据处理方法pad_and_truncate()中去,对这些结果进行补齐等操作。
该方法返回的最终结果就是训练过程中dataloader取一batch的数据。

class AFQMCDataset(Dataset):

    def __init__(self, data_dict: dict):
        super(AFQMCDataset, self).__init__()
        self.data_dict = data_dict

    def __getitem__(self, index: int) -> tuple:
        data = (self.data_dict['input_ids'][index], self.data_dict['token_type_ids'][index], self.data_dict['attention_mask'][index], self.data_dict['labels'][index])
        return data

    def __len__(self) -> int:
        return len(self.data_dict['input_ids'])


class Collator:
    def __init__(self, max_seq_len: int, tokenizer: BertTokenizer):
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenizer

    def pad_and_truncate(self, input_ids_list, token_type_ids_list, attention_mask_list, labels_list, max_seq_len):
        input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
        token_type_ids = torch.zeros_like(input_ids)
        attention_mask = torch.zeros_like(input_ids)
        for i in range(len(input_ids_list)):
            seq_len = len(input_ids_list[i])
            if seq_len <= max_seq_len:
                input_ids[i, :seq_len] = torch.tensor(input_ids_list[i], dtype=torch.long)
                token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i], dtype=torch.long)
                attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i], dtype=torch.long)
            else:
                input_ids[i] = torch.tensor(input_ids_list[i][:max_seq_len - 1] + [self.tokenizer.sep_token_id], dtype=torch.long)
                token_type_ids[i] = torch.tensor(token_type_ids_list[i][:max_seq_len], dtype=torch.long)
                attention_mask[i] = torch.tensor(attention_mask_list[i][:max_seq_len], dtype=torch.long)

        labels = torch.tensor(labels_list, dtype=torch.long)
        return input_ids, token_type_ids, attention_mask, labels

    def __call__(self, examples: list) -> dict:
        input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))
        cur_max_seq_len = max(len(input_id) for input_id in input_ids_list)
        max_seq_len = min(cur_max_seq_len, self.max_seq_len)

        input_ids, token_type_ids, attention_mask, labels = self.pad_and_truncate(input_ids_list, token_type_ids_list,
                                                                                  attention_mask_list, labels_list,
                                                                                  max_seq_len)

        data_dict = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

        return data_dict
        ```


动画参考:https://mp.weixin.qq.com/s/Uc2LYM6tIOY8KyxB7aQrOw
  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
下面是一个例子,展示如何用PyTorch自己写一个dataloader,并且dataloader集成自object对象。 1. 首先,需要导入PyTorch的DataLoaderDataset模块: ``` import torch from torch.utils.data import DataLoader, Dataset ``` 2. 接下来,定义一个自定义的Dataset类,继承自PyTorch的Dataset类,并实现__len__和__getitem__函数: ``` class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, index): return self.data[index] ``` 其,__init__函数用于初始化数据集,__len__函数用于返回数据集的大小,__getitem__函数用于返回指定索引的数据。 3. 然后,定义一个自定义的DataLoader类,继承自PyTorch的DataLoader类,并实现__init__和__iter__函数: ``` class CustomDataLoader(DataLoader): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None): super().__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn, multiprocessing_context) def __iter__(self): for batch in super().__iter__(): yield self.transform(batch) def transform(self, batch): # 对批次数据进行变换,这里仅作为示例,不做实际变换 return batch ``` 其,__init__函数用于初始化DataLoader,__iter__函数用于循环获取批次数据,并在获取前对数据进行变换(这里仅作为示例,不做实际变换)。 4. 最后,调用CustomDataLoader即可获取一个dataloader: ``` data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] batch_size = 2 dataset = CustomDataset(data) dataloader = CustomDataLoader(dataset, batch_size=batch_size, shuffle=True) for batch in dataloader: print(batch) ``` 这样就可以得到一个dataloader了。在本例,数据集是一个简单的数字列表,每个批次包含两个数字,dataloader会将数据集分成多个批次,每次输出一个批次的数据。自定义的DataLoader类继承自PyTorch的DataLoader类,并覆盖了__init__和__iter__函数,以实现自定义的功能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值