使用 PyTorch 读取自己的数据集

1. Dataset 和 DataLoader

我们经常见到这样一段代码:

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

其中,datasets 是 torchvision 的一个模块,通过它可以导入各种像 MINIST 等常用的数据集; torch.utils.data.DataLoader 是 torch 提供的一个有关划分的模块。

对应 自定义数据集,我们可以使用 torch.utils.data.Dataset torch.utils.data.DataLoader 两个模块对数据集处理


2. torch.utils.data.Dataset

类似 torchvision 中的 datasets, torch.utils.data.Dataset 可以加载数据集,并对数据经行必要的 transform,Dataset 的官方说明如下:

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite getitem(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite len(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

Dataset 类似于 C++ 的虚基类,其中的函数无具体定义,需要重载才能使用,我们需要重载以下方法:

  • __len__(self): 返回 example 的数目,返回整数
  • __getitem__(self, index): 根据一个 index 返回一条 example,返回一个 tuple(X, y)

3. 使用 CSV 存储样本的位置和类别

通常会使用一个 csv 存储样本的路径和类别,训练集和测试机分别对应一个 csv 文件:

此时在 Dataset 的 __init__(self) 中需要传入 csv_path, 以及数据集需要的做的 transform,接着读取 csv 中的样本 X 的路径,以及类别信息,计算样本数目

import torch
import argparse
import numpy as np
import pandas as pd
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader


class ReadDataFromCSV(Dataset):
    def __init__(self, csv_path, transform):
        # Transforms
        self.transform = transform
        
        # read csv
        self.data_info = pd.read_csv(csv_path, header=None)
        self.image_arr = np.asarray(self.data_info.iloc[1:, 0])
        self.label_arr = np.asarray(self.data_info.iloc[1:, 1])

        self.data_len = len(self.data_info.index)

接着,重载 __getitem__ 函数

  • getitem 通过 index 读取图片,并使用 self.transform() 对图片预处理(裁剪、缩放、转换为 tensor、归一化等)
  • getitem 通过 index 得到样本的 label 信息
  • 返回一个 tuple (image_tensor, label)
    def __getitem__(self, index):
        # get image
        single_img_name = self.image_arr[index]
        single_img_img = Image.open('../' + single_img_name)
        single_img_tensor = self.transform(single_img_img)

        # get label
        single_image_label = self.label_arr[index]

        return (single_img_tensor, single_image_label)

注意:csv_path 需要根据具体代码和 csv 的相对路径修改,图片相对代码的位置和csv中给出位置也要自适应调整一下!

然后,重构 __len__(self)

    def __len__(self):
        return self.data_len

4. 测试结果

接着测试一下,使用

if __name__ == '__main__':

	csv_path = '../SARimage/train.csv'
	torch_data = ReadDataFromCSV(
		csv_path='../SARimage/train.csv',
		transform=transforms.Compose(
			[transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
		)
    )

	for i, (X, y) in enumerate(torch_data):
		print(f'index:{i}\nX:{X}\ny:{y}')
		break

上面使用了 transforms 的 Resize、ToTensor、Normalize([0.5], [0.5]) 对数据进行变换,输出结构如下:

index:0
X:tensor([[[-0.7333, -0.7176, -0.7725, -0.7098, -0.7255, -0.7569, -0.7412,
          -0.7490, -0.7255, -0.7961, -0.7804, -0.7333, -0.7569, -0.7255,
          -0.7647, -0.7804, -0.7882, -0.7569, -0.7490, -0.8118, -0.7882,
          -0.7961, -0.7961, -0.7725, -0.7020, -0.6549, -0.6627, -0.7569],
         [-0.7569, -0.7333, -0.7725, -0.7412, -0.7412, -0.8039, -0.7569,
          -0.6235, -0.7569, -0.7255, -0.7412, -0.7098, -0.7255, -0.7020,
          -0.6471, -0.6784, -0.6941, -0.6471, -0.6784, -0.8039, -0.8196,
          -0.8118, -0.7961, -0.7490, -0.7490, -0.7020, -0.6784, -0.7333],
         ...
         [-0.7725, -0.7961, -0.7333, -0.7647, -0.7804, -0.8353, -0.8275,
          -0.7725, -0.7961, -0.8118, -0.8118, -0.8039, -0.8118, -0.8039,
          -0.8118, -0.7961, -0.7176, -0.7647, -0.7647, -0.7412, -0.7804,
          -0.7647, -0.7569, -0.7725, -0.8510, -0.8353, -0.7961, -0.7647]]])
y:2S1

具体使用 DataLoader 划分数据集方法如下:

dataloader = torch.utils.data.DataLoader(
    ReadDataFromCSV(
        csv_path = '../../SARimage/train.csv',
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
    drop_last=False,
)


5. label 信息

当没有 csv 文件时,可以使用 pathlib2 读取文件夹,遍历其中的所有文件,用文件夹名称作为label;

当我们需要使用 label 时,如果 label 是一个 str,则需要将 label 变换成数字类型,比如使用一个 dict 映射:

self.label_dict = {'2S1': 0, 'BMP2': 1, 'BRDM_2': 2,
      'BTR_60': 3, 'BTR70': 4, 'D7': 5, 'T62': 6, 'T72': 7, 'ZIL131': 8, 'ZSU_23_4': 9}

single_image_label = self.label_dict[self.label_arr[index]]

在 CGAN 中,我们可以这样子做:

class ReadDataFromCSV(Dataset):
    def __init__(self, csv_path, transform):
        # Transforms
        self.transform = transform
        
        # read csv
        self.data_info = pd.read_csv(csv_path, header=None).iloc[1:, :]
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])

        self.data_len = len(self.data_info.index)
        self.label_dict = {'2S1': 0, 'BMP2': 1, 'BRDM_2': 2,
        'BTR_60': 3, 'BTR70': 4, 'D7': 5, 'T62': 6, 'T72': 7, 'ZIL131': 8, 'ZSU_23_4': 9}

        
    def __getitem__(self, index):
        # get image
        single_image_name = self.image_arr[index]
        single_img_img = Image.open('../../' + single_image_name)
        single_img_tensor = self.transform(single_img_img)

        # get label
        single_image_label = self.label_dict[self.label_arr[index]]

        return (single_img_tensor, single_image_label)


    def __len__(self):
        return self.data_len


# 读取
dataloader = torch.utils.data.DataLoader(
    ReadDataFromCSV(
        csv_path = '../../SARimage/train.csv',
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
    drop_last=False,
)

REFERENCES:

  • 4
    点赞
  • 35
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值