pytorch——详解DataLoader中的sampler和collate_fn

最近在使用pytorch复现PointNet分割网络的过程中,在读入数据时遇到了一些问题,需要重写DataLoader中的sampler和collate_fn

Sampler

sampler的作用是按照指定的顺序向batch里面读入数据,自定义的sampler可以根据我们的需要返回索引,DataLoader会根据我们返回的索引值提取数据,生成batch

注意:
重写sampler需要重写__len__()和__iter__()方法,其中__len__()返回你读入数据的总长度,iter()返回一个迭代器

例如,我们需要sampler根据样本点数返回索引

class sampler(data.Sampler):
    """
    由于每个batch的点数可能不一致
    例如 len(b[0])=10220, len(b[1])=23300, len(b[2])=24000 , ...
    该sampler是为了将每个batch内的点数统一
    首先将batch里的样本按照点数从小到大排列
    返回排序之后的索引值
    """

    def __init__(self, data_source):
        super(sampler, self).__init__(data_source)
        self.x = data_source
        # y = data_source[1]
        self.lst = []
        for i in range(len(self.x)):
            self.lst.append(self.x[i].shape[0])
        self.idx = np.argsort(self.lst)  # 排序之后的索引

    def __iter__(self):
        return iter(self.idx)  # 这里的idx最后会返回给DataSets中的__getitem__方法

    def __len__(self):
        return len(self.x[0])  # 这里的__len__需要返回总长度

在data.DataLoader中找到了下面注释:

sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any Iterable with __len__
implemented. If specified, :attr:shuffle must not be specified.

意思是自己重写了sampler之后,shuffle关键字不能指定。

很重要的一点:
最后在传参的时候,传入的是sampler的示例
例如:

p_loader = data.DataLoader(p, batch_size, drop_last=True, sampler=sampler(sourcedata))  # p是DateSet的子类

collate_fn

collate_fn方法的作用是对于还未被连结batch进行操作,因为是没有被连结所以这里我说的batch是list类型
在这里插入图片描述

pytorch规定每一个batch中样本的点数必须相同,所以重写collate_fn方法,将每个batch中样本下采样到相同的数目
这里的函数的下采很简单,就是单纯的取得batch中的最小样本点数,将其他样本中的点shuffle之后取前最小个点数

def collate_fn(batch: list):
    """
    DataLoader中的最后一步
    对即将输出的batch进行操作
    这里将每一个batch的样本点数降到最少点数即min_num_pts
    :param batch: B*N*3 不同B中的N并不相同
    :return:
    """
    batch_size = len(batch)
    ret_cls = []
    for elem in batch:
        ret_cls.append(elem[2])
    ret_cls = torch.tensor(ret_cls)

    num_pts_lst = []
    for elem in batch:
        X, y = elem[0], elem[1]
        num_pts_lst.append(X.shape[0])
    sorted_lst = np.argsort(num_pts_lst)

    min_mun_pts = len(batch[sorted_lst[0]][0])

    temp_points = []
    temp_target = []
    for i in range(batch_size):
        X, y = batch[i][0], batch[i][1]
        torch.manual_seed(2021)
        X = X[torch.randperm(min_mun_pts)]
        torch.manual_seed(2021)
        y = y[torch.randperm(min_mun_pts)]
        temp_points.append(X)
        temp_target.append(y)
    ret_points = temp_points[0].unsqueeze(0)
    ret_target = temp_target[0].unsqueeze(0)

    for i in range(1, batch_size):
        cat_X = temp_points[i].unsqueeze(0)
        cat_y = temp_target[i].unsqueeze(0)
        ret_points = torch.cat([ret_points, cat_X], dim=0)
        ret_target = torch.cat([ret_target, cat_y], dim=0)
    return ret_points, ret_target, ret_cls

最后传入参数时,传入的是方法名

p_loader = data.DataLoader(p, batch_size, drop_last=True, collate_fn=collate_fn)

参考:
知乎大佬的sampler源码解读,很有用

  • 6
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值