pytorch实现数据导入网络模型

1整体介绍
在pytorch中要实现数据的训练,需要完成一下3步,创建Dataset对象,创建DataLoader迭代对象,遍历DataLoader对象
在这里插入图片描述
(1)Dataset对象
需要人为构建,这个类功能就是告诉pytorch你需要加载哪些数据,其有3个核心函数需要实现(其中getitem,len必须实现):
1:def init(self):数据在哪获取
2:def len(self):数据数量多少
在这里插入图片描述

3:def getitem(self, index):数据转成pytorch能处理的tensor格式
根据路径读取图像及其对应掩膜,图像及掩膜是否要进行transforms数据增强,返回成影像与掩膜数据
在这里插入图片描述

其中对于transforms数据增强需要我们人为去定义,目前常用的数据增强可以调用torchvision.transforms中提供的22种方式或者调用Albumentations库中提供的更为丰富的数据增强处理,下面以Albumentations中使用为例:然后在定义Dataset对象时直接调用即可。
在这里插入图片描述在这里插入图片描述
(2)DataLoader迭代对象
pytorch自带,该函数就是告诉pytorch怎么对读取进来的数据进行迭代遍历,该函数中常用的参数有:torch.utils.data.DataLoader(dataset,batch_size,shuffle(是否随机打乱),
num_workers(设置几个进程),collate_fn,drop_last(对于最后一组不整除的数据是否丢弃))
其中较难理解但常用的参数是collate_fn:该函数用于核对和整理数据,其输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果,输出是经过该函数整理后的结果。那我们什么时候会使用到这个参数呢?我们需要对数据进行进一步处理时。比如对于数据中可能产生被损害的照片,无法转成tensor,最优解当然是剔除或改善该照片,但有时候无法完全找到到底哪一张损坏,便可以考虑自定义collate_fn函数,通过过滤掉矩阵为None的照片,保证迭代不会存在问题batch = list(filter(lambda x:x[0] is not None, batch);或者我们想对读取的照片进行数据处理,比如都所有的像素都加10,也可以在这里进行自定义,然后返回每个像素增加10后的结果。即总的来说该函数使得我们的数据在进行迭代前有了在一次重新整理的机会,通常自定义格式如下:最后通常要进过pytorch官方的default_collate函数,其有很多功能,如将numpy转成可以遍历的tensor格式,确保我们迭代不会出错。
在这里插入图片描述
(3)遍历DataLoader对象
完成了上面两个类的定义后,便进入的训练的迭代过程
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值