【菜菜的CV进阶之路-Pytorch基础-数据处理】自定义数据集加载及预处理

前提:

本文的记录前提是---有一个完整、已调通的pytorch网络项目,因为暂时比赛要用,完整项目等过一段时间再打包发到github上...

比如:加载的pytorch自带cifar数据集:

# train、test图像预处理和增强
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])

transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])

#加载train、test数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

数据预处理torchvision.transforms这一部分主要是进行数据的中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等。如果不懂请查看官方文档:TORCHVISION.TRANSFORMS

构建Dataset子类

那么如果我想将cifar数据集换成我自己的数据集怎么办呢?答案是:如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。例如:

from __future__ import print_function
import torch.utils.data as data
import torch

class MyDataset(data.Dataset):
    def __init__(self, images, labels):#这一部分用于读取训练、测试数据
        self.images = images
        self.labels = labels

    def __getitem__(self, index):#这一部分将读取的数据输出,返回的是tensor格式
        img, target = self.images[index], self.labels[index]
        return img, target

    def __len__(self):
        return len(self.images)

dataset = MyDataset(images, labels)

查看torchvision.datasets.CIFAR10的源码也可以看到cifar10也是继承了Dataset这个类

在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是__len__和__getitem__。__len__返回数据集的大小,__getitem__实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。 
在创建DataLoader时会判断__getitem__返回值的数据类型,然后用不同的if/else分支把数据转换成tensor,所以,_getitem_返回值的数据类型可选择范围很多,一种可以选择的数据类型是:图像为numpy.array,标记为int数据类型。 

实例:

比如这里我需要读入我自己的32*32大小的图像数据,则代码为:

# 图像预处理和增强
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])

transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])

#加载train、test数据集
trainset = MyDataset(imgdir='./Train',imgpath='./Train.txt', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2)

testset = MyDataset(imgdir='./Test',imgpath='./Test.txt', train=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)

这里将参数说明一下:MyDataset中imgdir为图片的存放位置,imgpath中包含每张图片的name以及图片的label,其它同cifar

torch.utils.data.DataLoader()函数,合成数据并且提供迭代访问。主要由两部分组成: 

- dataset(Dataset)。输入加载的数据,就是上面的MyDataset的实现。 
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:

    - batch-size。样本每个batch的大小,默认为1。
    - shuffle。是否打乱数据,默认为False。
    - num_workers。数据分为几个线程处理默认为0。
    - sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False

之后就是在MyDataset里读取和输出自定义数据集: 简而言之就是在继承了Dataset类的Mydataset里,__init__函数中读取图像数据,DataLoader再通过__getitem__中获取输出

import os
import cv2
from PIL import Image
import numpy as np
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, imgdir, imgpath,train=True,
                 transform=None, target_transform=None):
        self.root = os.path.expanduser(imgdir)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        # now load the picked numpy arrays
        if self.train:
            self.train_data = []#images
            self.train_labels = []#labels
            #read the images and labels in file Train.txt or Test.txt
            with open(imgpath,"r") as imgpath:
                for line in imgpath:
                    line=line.split(' ')
                    image=Image.open(line[0])
                    image = np.array(image)
                    self.train_data.append(image)#将读取的图片放入train_data list里
                    self.train_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
                imgpath.close()
            self.train_data = np.array(self.train_data)
            self.train_data = self.train_data.reshape((1000,3, 32, 32))
            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
        else:
            self.test_data = []#images
            self.test_labels = []#labels
            #read the images and labels in file Train.txt or Test.txt
            with open(imgpath,"r") as imgpath:
                for line in imgpath:
                    line = line.split(' ')
                    image=Image.open(line[0])
                    image = np.array(image)
                    self.test_data.append(image)#将读取的图片放入train_data list里
                    self.test_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
                imgpath.close()
            self.test_data = np.array(self.test_data)
            self.test_data = self.test_data.reshape((200, 3,32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img,target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

参考:

https://blog.csdn.net/GYGuo95/article/details/78821520

https://blog.csdn.net/victoriaw/article/details/72356453

https://pytorch.org/docs/stable/torchvision/transforms.html

【房屋出租】通州北京宇涵文化创意园https://www.douban.com/note/763129847/

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

智慧地球(AI·Earth)社区

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值