动手制作自己的图片分类数据集

前言

在使用深度学习进行图片分类任务时, 往往都是用PyTorch进行模型搭建等操作. 如果需要进行的训练任务有现成的图片库, 那么如何将图片变成可以被PyTorch平台所承认的数据集呢? 这就需要自己创建可用的数据集, 就像很多案例中使用PyTorch自带的MNIST, Fashion-MNIST等数据集一样.

Dataset类

Dataset类是数据集的基类, 要想创建自己的数据集就需要自己创建数据集类集成并实现这个基类.
Dataset类位于PyTorch包中, 具体的位置为:

from torch.utils.data import Dataset

utils.data包也有加载器dataloader.
Dataset类的定义如下:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

其中给出了两个函数需要实现, 一个是__getitem__并且是根据下标获取; 另一个则是__add__, 如果只是简单进行分类任务的话, 实现__getitem__函数即可.

数据集类模板

首先创建自己的数据集类, 之后这个数据集类要实现为一个对象从中获取数据集.

class MyDataset(Dataset):
	def __init__(self, img_path):
		pass
	
	def __getitem__(self, idx):
		pass

	def __len__(self):
		pass

这是自定义数据集的模板, 实现这三个函数就可以完成数据集的简单定义.

  • __init__(self, img_path)
    这个函数用来初始化一些与数据集相关的参数. 比如这个数据集的路径信息. 通过这个路径信息可以将数据集中的所有图片都保存到这个类中. 此外, 由于必须要实现的函数__getitem__是通过下标进行获取图片的, 那么势必要将整个数据集中的图片转换为列表的形式, 从而根据下标进行图片的获取.
  • __getitem__(self, idx)
    这个函数是必须要实现的, 这是从数据集中获取数据的接口. 因为要进行训练任务, 那么只返回训练数据(也就是图片)是不够的, 还需要返回相应的标签.
  • __len__(self)
    这个数据可以帮助我们查看数据集中样本的总数量, 毕竟整个数据集的图片都被放到了一个列表结构中, 获取整体的长度也很简单.

实现细节

首先做一个假设. 假设你的数据集中存储的格式是这样的

dataset
	|-label_1
	|--img_01.png
	|--img_02.png
	|--  ...
	|-label_2
	|--img_01.png
	|--img_02.png
	|--  ...
	|-label_3
	|-  ...

就是说dataset文件夹下有许多子文件夹, 这些子文件夹中, 一个子文件夹就是单独一类, 每一类的图片就存放在对应的子文件夹下.

按照这个假设, 如何把整个数据集的图片用列表的形式保存起来, 同时也要保存对应的标签名呢?

在这里我只提供一个思路: 创建一个类, 保存图片的路径也保存图片的标签.
类似于:

class Img():
	def __init__(self, img_path, img_label):
		self.img_path = img_path
		self.img_label = img_label

在保存信息的同时也可以保存其他的信息. 将图片组织成这样的对象存储起来, 然后在后面获取的时候再根据图片的路径创建图片, 再结合标签返回就可以了.
例如:

import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class Img():
	def __init__(self, img_path, img_label):
		self.img_path = img_path
		self.img_label = img_label
	
class MyDataset(Dataset):
	# 这里的dirlist是数据集文件夹的地址
	# 下面的子文件夹才是各个类
	def __init__(self, dirlist):
		self.dirlist = dirlist
		self.label_list = os.listdir(self.dirlist)
		self.img_obj_list = [] # 存储Img对象
		for label in self.label_list:
			# 每一个label都是一个子文件夹名称
			# 和dirlist拼接起来才是图片实际存储的地址
			self.img_dir_path = os.path.join(self.dirlist, label)
			self.img_list = os.listdir(self.img_dir_path)
			for img_path in self.img_list:
				img_obj_list.append(Img(os.path.join(self.img_dir_path, img_path), label))
	
	def __getitem__(self, idx):
		img_obj = self.img_obj_list[idx]
		img = Image.open(img_obj.img_path)
		# 因为导入pytorch中的图片格式必须是tensor或者ndarray, 所以要转换
		# 但要注意: ndarray格式的图片通道为[H, W, C]而非pytorch常用的[C, H, W]
		img =  np.array(img)
		label = img_obj.img_label
		return img, label

	def __len__(self):
		return len(self.img_obj_list)

至此, 就简单完成了数据集的制作

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值