pytorch模型构建(一)——datasets部分之dataloader

一、简介

1. torch中的dataloader:

torch.utils.data.DataLoader(dataset,
						 batch_size=1, 
						 shuffle=False, 
						 sampler=None,
						 batch_sampler=None, 
						 num_workers=0, 
						 collate_fn=None,
						 pin_memory=False, 
						 drop_last=False, 
						 timeout=0,
						 worker_init_fn=None, *, 
						 prefetch_factor=2,
						 persistent_workers=False)

2. 参数解释:

  • dataset: data和labels,一般继承torch.utils.data.Dataset,可在其中做数据增广等处理。
  • batch_size: default=1
  • shuffle: default=False,是否打乱数据集。
  • sampler: 定义采样策略,若自己定义,则shuffle要设定False。如使用难采样三元组损失时,就需要在一个batch内对当前样本进行自定义的采样规则。
  • batch_sampler:sampler类似,但是返回的是一个batch的index,与shuffle, sampler, drop_last互斥
  • num_workers: default=0,注意windows中使用时候一般设置为0,要不然会出错或者速度更慢
  • collate_fn: 将一个batch内的imgs和labels合并,如果只返回img和label,那么可以使用默认的collate_fn,但是如果返回img box label,每一个img的box数目不一定相同,所以就需要在这个函数里面加入当前box属于当前batch的哪一张图片,就需要自定义collate_fn将对应的数据合并成一个batch。再如,使用mosaic数据增强时,需要将处理好的4张图片拼接起来,同样对应的label也要拼接。
  • pin_memory: 内存大就开,不大就不开。表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False。通常情况下,数据在内存中要么以锁页的方式存在,要么保存在虚拟内存(磁盘)中,设置为True后,数据直接保存在锁页内存中,后续直接传入cuda;否则需要先从虚拟内存中传入锁页内存中,再传入cuda,这样就比较耗时了,但是对于内存的大小要求比较高。
  • drop_last: 最后的数据不够一个batch_size时,是否选择舍去。默认为False
  • timeout: default=0,一般不用管
  • worker_init_fn: 一般不用管
  • prefetch_factor: 一般不用管
  • persistent_workers: 一般不用管

3. 思路:
常见的需要自己自定义按照自己需要重写参数有datasetsamplercollate_fn

二、构造过程

1. 思路:
先将读取的dataset函数进行重写(包括数据增强、矩形推理等都在这个部分,要注意当图像坐标发生变化时对应的label box的坐标也要对应的进行处理,最好将输出后的img和box坐标画出来,看看是否处理错误),以及采样规则sampler,batch中的样本结合处理。

1. 构建dataset

from torch.utils.data import Dataset

class LoadImgsAndLabels(Dataset):
    def __init__(self, path, img_szie=640, batch_size=16, argument=False, hyp=None):
        '''
            一般情况下,这里有图片路径(label路径可根据图片路径对应的改成自己的文件路径)
            img_size,图像一般裁剪到统一大小进行输入
            argument,是否使用数据增强
            hyp,处理过程中的超参列表,如数据增强中图像上下左右翻转的概率,有多大的概率进行mosaic等等
            还有其他需要的一些参数需要自定义
        '''

        # 1. 获得path下所有图片的绝对路径 self.img_files: []
        '''
            这里使用 try except Exception as e 判断是否加载数据错误之类的
            将path转换为 pathlib.Path 可以生成与os无关的分隔符,可以结合os.sep使用。
            用glob.glob和设置一个所有图片或者视频后缀列表进行判断, 筛选出所有的图片或者需要的文件。
            
        '''
        # 2. 根据获得的图片路径self.img_files 转换为 标签路径:self.label_files
        '''
            写个转换函数 img2label_paths
            self.label_files = img2label_paths(self.img_files)
        '''

    def __len__(self):
        '''
            返回当前数据集的长度(有多少张图片)
            return len(self.img_files)
        '''
        pass

    def __getitem__(self, index):
        '''
            这部分包含读取数据,数据增强, 一般一次性执行batch_size次
            可分为训练和测试
            eg:
                训练 数据增强: mosaic(random_perspective) + hsv + 上下左右翻转
                测试 数据增强: letterbox
        '''

        # 1. 读取图片和labels
        '''
            常规:使用index读取图片和labels
            数据增强:不同的数据增强可以写不同的读取图片的函数
                如:img, labels = load_mosaic(self, index)
                    img, (h0, w0), (h, w) = load_image(self, index) + letterbox + labels对应的处理
        '''
        pass

2. 构建sampler
遇到再写,一般在使用三元组损失之类的loss时候需要重写,比如图片匹配、行人重识别任务中使用。
或者是分布式采样:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)

3. 构建collate_fn

    def collate_fn(batch):
        """
        pytorch的DataLoader打包一个batch的数据集时要经过此函数进行打包 通过重写此函数实现标签与图片对应的划分,
        一个batch中哪些标签属于哪一张图片,形如
            [[0, 6, 0.5, 0.5, 0.26, 0.35],
             [0, 6, 0.5, 0.5, 0.26, 0.35],
             [1, 6, 0.5, 0.5, 0.26, 0.35],
             [2, 6, 0.5, 0.5, 0.26, 0.35],]
           前两行标签属于第一张图片, 第三行属于第二张。。。
        """
        img, label, path, shapes = zip(*batch)  # transposed
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()

        # 这里之所以拼接的方式不同是因为img拼接的时候它的每个部分的形状是相同的,都是[3, 736, 736]
        # 而label的每个部分的形状是不一定相同
        # 如果每张图的目标个数是相同的,那我们就可能不需要重写collate_fn函数了
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes

4. 创建最后的dataloader

def create_dataloader(path, imgsz, batch_size, hyp=None, augment=False):
    dataset = LoadImagesAndLabels(path, imgsz, batch_size, augment=augment, hyp=hyp) 
    loader = torch.utils.data.DataLoader
    dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn)
    return dataloader, dataset
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值