【Pytorch】Load your own dataset

在这里插入图片描述

pytorch 在载入数据时用torchvision.datasets.ImageFolder 配合 torch.utils.data.DataLoader 很方便,但是只能遍历图片和图片的标签,无法灵活的获取图片的其他信息,比如图片的名字,本文介绍如何定义自己的 ImageFolder,在使用 Dataloader 时实现获取图片名字的功能!



1 ImageFolder and DataLoader

以分类为例,用 pytorch 的 torchvision.datasets.ImageFolder 配合 torch.utils.data.DataLoader 即可对数据按类别进行读取、预处理、分成 batch

import torchvision
import torch

train_dataset = torchvision.datasets.ImageFolder(
    train_data_pth,
    transforms.Compose([
        transforms.Resize(input_size,interpolation=2), # resize
        transforms.ToTensor(), # ToTensor
        normalize,])) # Normalization

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size, # set batchsize
    shuffle=False,
    num_workers=n_worker,
    pin_memory=True)

参考


ImageFoldertrain_data_pth 是存放数据集的文件夹,文件结构应该如下

train_data_pth
	class1
		xxx.jpg
		...
	class2
		xxx.jpg
		...
	...
	classn
		xxx.jpg
		...

Dataloader 的参数介绍如下

  • dataset:加载的数据集(Dataset对象)
  • batch_size:batch size
  • shuffle:是否将数据打乱
  • sampler: 样本抽样,后续会详细介绍
  • num_workers:使用多进程加载的进程数,0 代表不使用多进程
  • collate_fn: 如何将多个样本数据拼接成一个 batch,一般使用默认的拼接方式即可’
  • pin_memory:是否将数据保存在pin memory 区,pin memory 中的数据转到 GPU 会快一些
  • drop_last:dataset中的数据个数可能不是 batch_size 的整数倍,drop_last 为 True 会将多出来不足一个batch的数据丢弃

参考 pytorch之DataLoader()函数

官网中 Dataloader 的介绍如下(https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)

在这里插入图片描述
在这里插入图片描述

强调下,num_workers 表示一次可以装载 num_workers 个 batch,而不是一次装载一个 batch。

在训练和测试时,可用如下循环来对数据进行操作

for batch_images, batch_labels in train_loader:
	pass

把数据集的文件夹建立好,直接调用 ImageFolderDataLoader 来进行数据的载入分批读取确实很方便,但是如果我们想知道哪些图片分类错误了, train_loader loader 中仅有 image(图片) 和 label 属性,没有 image name(图片名称) 属性,有些力不从心!

因此,我们可以自己写 ImageFolder 来实现读取 image、label、image name 的功能,当然熟悉这个流程后,以后可以进行更个性化的操作!

参考 从零开始深度学习Pytorch笔记(11)—— DataLoader类


补充

collate_fn参数用于是否需要以自定义的方式组织一个batch, 例子中将一个mini-batch的数据组织成numpy.ndarray的类型. 默认情况下collate_fn=None时,数据以元组的方式返回.

比如将多个numpy小数组组合成一个大的numpy数组:

def ssd_dataset_collate(batch):
    # print('ssd_dataset_collate函数被执行...')
    images = []
    bboxes = []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)
    images = np.array(images)
    bboxes = np.array(bboxes)
    return images, bboxes


gen = DataLoader(train_dataset, \
                        batch_size=Batch_size, \
                        num_workers=8, \
                        pin_memory=True,\
                        drop_last=True, \
                        collate_fn=ssd_dataset_collate)
                        # collate_fn=None)

来自 torch.utils.data.DataLoader中的collate_fn的使用

2 OwnFolder and DataLoader

自己写数据读取和预处理,来替代 torchvision.datasets.ImageFolder 的功能,具体实现如下 class Own_Dataset 所示

class Own_Dataset(Dataset):
    def __init__(self, image_label_list, transform=None):
        super().__init__()
        self.samples_list = image_label_list  # xxx.jpg class1 
        self.transform = transform  # pre-processing of data

    def __getitem__(self, index):
        img_name = self.samples_list[index][0] # absolute path of image name
        with open(img_name,"rb") as f:
            img = Image.open(f).convert("RGB") # load image
        label = self.samples_list[index][1] # image label
        if img is None:
            print(img_name)
        if self.transform is not None:
            img = self.transform(img)
        return img, label, img_name

    def __len__(self):
        return len(self.samples_list)

其中 image_label_list 为列表,存放着图片的绝对路径以及标签信息,格式如下

[(/train_data_pth/calss1/1.jpg,class1),
(/train_data_pth/calss1/2.jpg,class1),
...,
(/train_data_pth/calssn/m.jpg,classn)]

想实现更多功能,在 def __getitem__(self, index): 中定义即可,

__getitem__:实例[idx] 时触发

参考:【python】类(11)

配合 DataLoader 使用

train_loader = torch.utils.data.DataLoader(
    Own_Dataset(image_label_list=val_list,
               transform=transforms.Compose([
               	   transforms.Resize(input_size,interpolation=2), # resize
                   transforms.ToTensor(),
                   normalize,])),
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=n_worker,
    pin_memory=True)

训练测试时,就可以访问图片,类别以及图片名信息了,如下所示

for batch_images, batch_labels,batch_names in train_loader:
	pass

3 transforms

下面介绍部分 torchvision.transforms 方法

更多的 torchvision.transforms 方法可以参考官网介绍

https://pytorch.org/docs/stable/torchvision/transforms.html

train_dataset = datasets.ImageFolder(
    train_data_pth,
    transforms.Compose([
        transforms.Resize(scale_size,interpolation=2),
        transforms.RandomRotation(5),
        transforms.ColorJitter(brightness=0.1,contrast=0.1,
                               saturation=0.1,hue=0.1),
        transforms.FiveCrop(input_size),
        transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(
            mean = [0.5,0.5,0.5],
            std = [0.5,0.5,0.5])(crop) for crop in crops]))
    ]))

input_size 和 scale_size 写成元组的形式,eg,(224,224) 和 (256,256)

Normalize 时注意 mean 和 std 一定要除以 255,值介于 0~1 之间

FiveCrop 或者 TenCrop 时,测试代码也需要进行相应的调整,如下

原来

out = net(batch_images)

现在

bs, ncrops, c, h, w, = batch_images.size()
result = net(batch_images.view(-1,c,h,w))
out = result.view(bs,ncrops,-1).mean(1)

计算数据集的均值和标准差

import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Image


def compute_mean_and_std(dataset):
    # 输入PyTorch的dataset,输出均值和标准差
    mean_r = 0
    mean_g = 0
    mean_b = 0

    for img, _ in dataset:
        img = np.asarray(img) # change PIL Image to numpy array
        mean_r += np.mean(img[:, :, 0])
        mean_g += np.mean(img[:, :, 1])
        mean_b += np.mean(img[:, :, 2])

    mean_r /= len(dataset)
    mean_g /= len(dataset)
    mean_b /= len(dataset)

    diff_r = 0
    diff_g = 0
    diff_b = 0

    N = 0

    for img, _ in dataset:
        img = np.asarray(img)

        diff_r += np.sum(np.power(img[:, :, 0] - mean_r, 2))
        diff_g += np.sum(np.power(img[:, :, 1] - mean_g, 2))
        diff_b += np.sum(np.power(img[:, :, 2] - mean_b, 2))

        N += np.prod(img[:, :, 0].shape)

    std_r = np.sqrt(diff_r / N)
    std_g = np.sqrt(diff_g / N)
    std_b = np.sqrt(diff_b / N)

    mean = (mean_r.item() / 255.0, mean_g.item() / 255.0, mean_b.item() / 255.0)
    std = (std_r.item() / 255.0, std_g.item() / 255.0, std_b.item() / 255.0)
    return mean, std
  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
首先,你需要将心电信号的数据存储在某种形式的数据文件中,例如CSV或HDF5文件。数据文件应该包含每个样本的心电信号,以及与每个样本相关的标签。标签可以是类别标签,例如正常和异常,或者是连续标签,例如心率和心电图形态特征。 然后,你需要创建一个PyTorch dataset类,该类将读取数据文件并将数据加载到PyTorch张量中。以下是一个基本的示例: ```python import torch from torch.utils.data import Dataset, DataLoader class EcgDataset(Dataset): def __init__(self, data_file): # 从数据文件中读取心电信号和标签 self.data, self.labels = self._load_data(data_file) # 将数据和标签转换为PyTorch张量 self.data = torch.tensor(self.data, dtype=torch.float32) self.labels = torch.tensor(self.labels, dtype=torch.int64) def __len__(self): return len(self.labels) def __getitem__(self, index): return self.data[index], self.labels[index] def _load_data(self, data_file): # 从数据文件中读取数据和标签 # ... return data, labels ``` 在上面的示例中,`EcgDataset`是一个继承自PyTorch `Dataset`类的自定义类。在`__init__`方法中,我们从数据文件中读取数据和标签,并将它们转换为PyTorch张量。`__len__`方法返回数据集的大小,`__getitem__`方法返回给定索引的数据和标签。 一旦你创建了`EcgDataset`类,你可以使用PyTorch数据加载器(`DataLoader`)来加载数据,例如: ```python dataset = EcgDataset(data_file='data.csv') dataloader = DataLoader(dataset, batch_size=32, shuffle=True) ``` 这将创建一个数据加载器,它将从`EcgDataset`中加载数据,每批次包含32个样本,并且每个批次都随机洗牌。你可以使用数据加载器来迭代数据,例如: ```python for data, labels in dataloader: # 在这里训练模型或做其他事情 # ... ``` 希望这可以帮助你开始设置你的心电图数据集。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值