【无标题】利用self.collate_fn自定义DataLoader

自定义DataLoader

这是为了补上次讲自定义Dataset挖的坑

连接如果想看前文 链接: 自定义Dataset


前情提要

由于DataLoader的进行是需要在Dataset的基础上,这是Dataset的基本结构

class MyDataset(Dataset):
    def __init__(self, datas, label_list):
      self.datas = datas
      self.labels = []

      self.label_list = label_list
      keys = list(set([y for y in self.label_list]))
      keys.sort()
      dictkeys = {key: ii for ii, key in enumerate(keys)}
      print(dictkeys)

      for i in range(len(self.label_list)):
        self.labels.append(dictkeys[label_list[i]])

    def __len__(self):
      return len(self.labels)

    def __getitem__(self, idx):
      return torch.FloatTensor(self.datas[idx]), torch.tensor(self.labels[idx])

数据准备,data的shape为(26, 10),第一行全为1,第二行全为2,一直到最后一行全为26。label为26个字母

import numpy as np
# 创建一个 shape 为 (24, 10) 的数组
array_shape = (24, 10)

# 使用 np.arange() 函数生成从 1 到 24 的数组
data_array = np.arange(1, array_shape[0] + 1)

# 使用 np.tile() 函数将每个元素复制 10 次,构成行
result_array = np.tile(data_array, (array_shape[1], 1)).T

print(result_array)

# 使用列表推导式生成从 A 到 Z 的字母列表
alphabet_list = [chr(i) for i in range(ord('A'), ord('Z') + 1)]

print(alphabet_list)

到重点了

其实就是继承DataLoader类,然后将自己写的方法替代其中的self.collate_fn方法

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader

def _collate_fn(batch):
    datas = []
    labels = []
    for i in range(len(batch)):
        datas.append(batch[i][0])
        labels.append(batch[i][1])
    datas = torch.stack(datas, dim=0)
    labels = torch.stack(labels)
    return datas, labels

class MyDataLoader(DataLoader):
        def __init__(self, *args, **kwargs):
                super(MyDataLoader, self).__init__(*args, **kwargs)
                self.collate_fn = _collate_fn
        

至于self.collate_fn起什么作用,我们都知道DataLoader中有一个参数batch_size,假如我们将batch_size设置为4,那么你可以理解为,调用Dataset中的__getitem__(self, idx)4次

然后将__getitem__(self, idx)得到的返回值设为元组,再将四个元组添加到一个list中,也就变成我们所说的一个batch, 这个batch便会作为参数传给self.collate_fn函数

其中还有一个参数shuffle,如果设置为True,那么就会随机的设置__getitem__(self, idx)的idx, 设置为False,按顺序取idx

最后就可以使用了

train_dataset = MyDataset(result_array, alphabet_list)
train_dataloader = MyDataLoader(train_dataset, batch_size=4, shuffle=False)

for idx, (data, label) in enumerate(train_dataloader):
    print(label)

  • 21
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
class DistributedSampler(_DistributedSampler): def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): super().__init__(dataset, num_replicas=num_replicas, rank=rank) self.shuffle = shuffle def __iter__(self): if self.shuffle: g = torch.Generator() g.manual_seed(self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() else: indices = torch.arange(len(self.dataset)).tolist() indices += indices[:(self.total_size - len(indices))] assert len(indices) == self.total_size indices = indices[self.rank:self.total_size:self.num_replicas] assert len(indices) == self.num_samples return iter(indices) def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4, seed=None, logger=None, training=True, merge_all_iters_to_one_epoch=False, total_epochs=0): dataset = __all__[dataset_cfg.DATASET]( dataset_cfg=dataset_cfg, class_names=class_names, root_path=root_path, training=training, logger=logger, ) if merge_all_iters_to_one_epoch: assert hasattr(dataset, 'merge_all_iters_to_one_epoch') dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs) if dist: if training: sampler = torch.utils.data.distributed.DistributedSampler(dataset) else: rank, world_size = common_utils.get_dist_info() sampler = DistributedSampler(dataset, world_size, rank, shuffle=False) else: sampler = None dataloader = DataLoader( dataset, batch_size=batch_size, pin_memory=True, num_workers=workers, shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch, drop_last=False, sampler=sampler, timeout=0, worker_init_fn=partial(common_utils.worker_init_fn, seed=seed) ) return dataset, dataloader, sampler
最新发布
07-12
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值