如何创建一个pytorch的训练数据加载器(train_loader)用于批量加载训练数据

Talk is cheap,show me the code!  哈哈,先上几段常用的代码,以语义分割的DRIVE数据集加载为例:

DRIVE数据集的目录结构如下,下载链接DRIVE,如果官网下不了,到Kaggle官网可以下到:

1. 定义DriveDataset类,每行代码都加了注释,其中collate_fn()看不懂没关系:

import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):  # 继承Dataset类
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        self.flag = "training" if train else "test"  # 根据train这个布尔类型确定需要处理的是训练集还是测试集
        data_root = os.path.join(root, "DRIVE", self.flag)  # 得到数据集根目录
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."   # 判断路径是否存在
        self.transforms = transforms   # 初始化图像变换操作
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]   # 遍历图像文件夹获取每个图像的文件名
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]  # 获取图像路径
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")  # 获取手动标签的路径
                       for i in img_names]
        # 检查手动标签文件是否存在
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")
        # 获取分割的ROI区域掩码
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                         for i in img_names]
        # check files
        for i in self.roi_mask:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

    def __getitem__(self, idx):
        img = Image.open(self.img_list[idx]).convert('RGB')  # 加载图像,并转换为RGB模式
        manual = Image.open(self.manual[idx]).convert('L')   # 加载手动标注图像,并转换为灰度模式
        manual = np.array(manual) / 255   # 进行归一化操作
        roi_mask = Image.open(self.roi_mask[idx]).convert('L')  # 加载ROI图像,并转换为灰度模式
        roi_mask = 255 - np.array(roi_mask)   # 对图像数组取反,使用这个方法将背景和前景颜色反转,白色是255,黑色是0,反转后ROI变成了内黑外白
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)   # 将手动标注图像和反转后的ROI图像相加,使用np.clip()将像素值控制在0-255范围,

        # 这里转回PIL的原因是,transforms中是对PIL数据进行处理
        mask = Image.fromarray(mask)

        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        return img, mask

    # 获取图像数据集长度
    def __len__(self):
        return len(self.img_list)

    # 用于将批量的图像和标签数据合并为一个批张量。
    @staticmethod
    def collate_fn(batch):
        images, targets = list(zip(*batch))  # 将批量数据拆分为图像和标签两个列表
        batched_imgs = cat_list(images, fill_value=0)  # 使用 cat_list() 函数将图像和标签列表合并成张量。用于将列表中的 PIL 图像数据堆叠成张量,
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))  # 找到图像中最大的形状,以元组形式返回给max_size
    batch_shape = (len(images),) + max_size   # 计算出堆叠后的张量形状,包括批量大小和图像大小两个维度
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)  # 创建一个新的空白张量 batched_imgs,其形状与 batch_shape 相同,并将其填充为指定的填充值 fill_value
    for img, pad_img in zip(images, batched_imgs):  # 使用 zip() 函数将输入列表中的每个图像与其对应的空白张量进行拼接,以得到一个完整的张量。
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)  # 将每个图像按照其实际大小插入到空白张量的左上角,以保持图像的相对位置不变。
    return batched_imgs

2. 构建训练集和验证集对象。调用上述自定义的DriveDataset数据集类,通过传入不同的参数来区分训练集和验证集。arg.data_path表示数据集所在的路径,transforms参数则是表示对数据进行预处理的操作,包括图像增强和归一化等,mean和std是规范化处理时用到的均值和标准差。get_transform这个类下面会介绍。train=True表示构建训练集对象,train=False表示构建验证集对象。

train_dataset = DriveDataset(args.data_path,
                                 train=True,
                                 transforms=get_transform(train=True, mean=mean, std=std))
val_dataset = DriveDataset(args.data_path,
                               train=False,
                               transforms=get_transform(train=False, mean=mean, std=std))

3.这是定义图像预处理方式,包括训练集和测试集的图像和标签的预处理方式,每行代码的具体作用注释有介绍。

# 定义训练集图像的预处理方式
class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(1.2 * base_size)

        trans = [T.RandomResize(min_size, max_size)]   # 对图像的短边(长和宽中最短的)进行随机缩放以适应不同图像输入尺寸,缩放范围为【min_size, max_size】
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))  # 加入水平翻转
        if vflip_prob > 0:
            trans.append(T.RandomVerticalFlip(vflip_prob))  # 加入垂直翻转
        trans.extend([
            T.RandomCrop(crop_size),   # 对图像进行随机裁剪
            T.ToTensor(),  # 将数组矩阵转换为tensor类型,规范化到【0,1】范围
            T.Normalize(mean=mean, std=std),  # 加入图像归一化,并定义均值和标准差,RGB三通道的
        ])
        # trans是一个列表类型,包含各种了变换,将这些变换组成一个compose变换,注意transforms.Compose()函数需要接收一个列表类型
        self.transforms = T.Compose(trans)

    # 使用__call__()函数来调用transforms变换
    def __call__(self, img, target):
        return self.transforms(img, target)  # target是指标签图像,img是指待分割图像


# 定义验证集的图像预处理组合类,比较简单,只有张量化和规范化两个操作,这里规范化使用的是ImageNet推荐的参数,注意这种做法是针对彩色图像
class SegmentationPresetEval:
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)

# 定义一个函数根据数据集的类型来调用对应的数据集处理类
def get_transform(train, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    base_size = 565
    crop_size = 480
    # 检查train是否为True
    if train:
        return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
    else:
        return SegmentationPresetEval(mean=mean, std=std)

4. 定义训练时所使用的线程数目,这里如果时windows系统训练出错的话,建议把num_workers直接设置为0就可以解决。定义训练数据加载器train_loader,用于批量加载训练数据。在代码中,使用torch.utils.data.DataLoader类来创建数据加载器,构造函数的参数包括:

  • train_dataset:训练数据集,应该是一个符合 PyTorch Dataset 接口的对象。
  • batch_size:每个批次的样本数量。
  • num_workers:用于数据加载的线程数。
  • shuffle:是否在每个时期(epoch)重新洗牌数据。一般只在训练集用
  • pin_memory:是否将数据加载到固定的内存区域,可以加速数据传输。
  • collate_fn:用于将样本列表转换为批次张量的函数。

num_workers = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])    # 如果batch_size>1, 线程数num_workers取min(cpu核数,batch_size)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=1,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             collate_fn=val_dataset.collate_fn)

至此数据集加载器创建完成!

  • 10
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Trouville01

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值