dataset.py
统一将图像返回成torch能处理的[original_iamges.tensor,label.tensor]
torch.utils.data.DataLoader()
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
重点关注四个参数:
batch_size: 批处理数目
shuffle: 是否每个epoch都打乱
workers: 载入数据的线程数
dataset: 是经过变换的自己的数据集(即:一个继承了torch.utils.data.Dataset类的子类的实例
---通常是自己需要什么处理什么---
通常在框架里面填写具体的东西:
是否transform如裁剪、归一化、旋转等?如果要transform则还需要区分test和train。比如train需要随机翻转,但是test则不需要操作.如何做到一张一张对应读取图片? 可以自定义这些函数。
必须要重载的是__getitem__()和__len__()。
__len__():len(dataset)返回数据集的大小。
__getitem__():实现数据集的下标索引,使用dataset[i]来得到第i个样本(图像和标记)。
--------------------------
import torch.utils.data as data
import torch
from torchvision import transforms
class MyTrainData(torch.utils.data.Dataset) #子类化
def __init__(self, root, transform=None, train=True): #第一步初始化各个变量
self.root = root
self.train = train
def __getitem__(self, idx): #第二步装载数据,返回[img,label],idx就是一张一张地读取
# get item 获取 数据
img = imread(img_path) #img_path根据自己的数据自定义,灵活性很高
img = torch.from_numpy(img).float() #需要转成float
gt = imread(gt_path) #读取gt,如果是分类问题,可以根据文件夹或命名赋值 0 1
gt = torch.from_numpy(gt).float()
return img, gt #返回 一一对应
def __len__(self):
return len(self.imagenumber) #这个是必须返回的长度