使用pytorch实现对自定义数据集的预处理

任务目标

实现对Pokemon数据集的预处理;(数据集在文末提供)

任务思路

  1. 了解数据集结构;

  1. 训练集、验证集与测试集的预划分;

  1. 数据集的加载。

实现过程

  1. 了解数据集结构

数据集文件夹结构

pokemon下属的每个子文件夹名称即为其下所有图片的label,如下图:

pokemon

├─bulbasaur

│ 00000000.png

│...

├─charmander

│ 00000000.png

│...

├─mewtwo

│ 00000000.jpg

│...

├─pikachu

│ 00000000.jpg

│...

└─squirtle

00000000.png

图片总数为1168,其中

  • 各个类别图片数:

皮卡超(pikachu):234;超梦(mewtwo):239;杰尼龟(squirtle):223;

小火龙(charmander):238;妙蛙种子(bulbasaur):234;

  • 图片格式:

.png:506;.jpg:657;.jpeg:4;.gif:1;

  1. 训练集、验证集与测试集的预划分

train

val

test

60%

20%

20%

  1. 数据集的加载

# 导包
import csv
import glob
import os
import random

import torch
from PIL import Image
from torch.utils.data import Dataset  # 自定义数据集的母类
from torch.utils.data import DataLoader
from torchvision import transforms
  • 搭建基本框架

class Pokemon(Dataset):
    def __init__(self, root, resize):  # 初始化函数
        super(Pokemon, self).__init__()
    
        self.root = root  # 数据集文件夹/pokemon所在路径
        self.resize = resize  

    def __len__(self):  # 获取总样本数量的函数
        pass

    def __getitem__(self, idx):  # 返回当前idx样本数据及label的函数
        pass
  • 设置样本与标签的映射

将各图片与其所属文件夹名称作的label相对应

def __init__(self, root, resize):  # 初始化函数
    super(Pokemon, self).__init__()

    self.root = root  # 数据集文件夹/pokemon所在路径
    self.resize = resize

    self.name2label = {}  # 创建形如{"name":label}的字典
    for name in sorted(os.listdir(os.path.join(root))):  # 将/pokemon下的文件夹名称排序后依次读入
        if not os.path.isdir(os.path.join(root, name)):  # 忽略除文件夹外的其他文件
            continue
 
        self.name2label[name] = len(self.name2label.keys())  # 将文件夹名称与自然数对应
    print(self.name2label)  # 1168, ['pokemon\\bulbasaur\\00000000.png'...]
  • 创建.csv文件

将所有图片样本的路径与其标签以形如image_path(不直接加载图片本身,防止爆内存), label对象的形式保存到一个.csv文件中

def load_csv(self, filename):

    if not os.path.exists(os.path.join(self.root, filename)):  # 判断.csv文件是否存在
        images = []
        # 将所有文件夹下的各种格式的图片以 'pokemon\\bulbasaur\\00000000.png' 的形式保存images列表中
        for name in self.name2label.keys():
            images += glob.glob(os.path.join(self.root, name, '*.png'))
            images += glob.glob(os.path.join(self.root, name, '*.jpg'))
            images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            images += glob.glob(os.path.join(self.root, name, '*.gif'))

        print(len(images), images)  # 1168, ['pokemon\\bulbasaur\\00000000.png'...]

        random.shuffle(images)  # 将图片顺序随机打乱
        with open(os.path.join(self.root, filename), mode='w', newline='') as f:
            writer = csv.writer(f)
            # 将image列表的信息写入.csv文件中
            for img in images:
                name = img.split(os.sep)[-2]  # 取 'pokemon\\bulbasaur\\00000000.png' 的倒数第二项作为name
                label = self.name2label[name]  # 通过name获得图片的label
                writer.writerow([img, label])  # 以 'pokemon\\bulbasaur\\00000000.png', 0 为一行的形式写入.csv文件中
            print('writen into csv file:', filename)
  • 读取样本与标签

从.csv文件中读取刚刚保存的样本与标签,并将其添加到image和label列表中

    # 接上文
    images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row  # 读取.csv文件的一行
                label = int(label)  # 将label转为int类型

                images.append(img)
                labels.append(label)

            assert len(images) == len(labels)  # 保证样本与标签数目一致,一一对应
            return images, labels
  • 划分训练集、验证集与测试集

依照初始化函数的mode参数划分训练集、验证集与测试集

def __init__(self, root, resize, mode):  # 初始化函数
    super(Pokemon, self).__init__()

    self.root = root  # 数据集文件夹/pokemon所在路径
    self.resize = resize

    self.name2label = {}  # 创建形如{"name":label}的字典
    for name in sorted(os.listdir(os.path.join(root))):  # 将/pokemon下的文件夹名称排序后依次读入
        if not os.path.isdir(os.path.join(root, name)):  # 忽略除文件夹外的其他文件
            continue

        self.name2label[name] = len(self.name2label.keys())  # 将文件夹名称与数字对应
    # print(self.name2label)  # 1168, ['pokemon\\bulbasaur\\00000000.png'...]

    # image, label
    self.images, self.labels = self.load_csv('image.csv')  # 将上文的images和labels读进来

    # 将全部图片按照预处理的比例划分
    if mode == 'train':  # 60% = 0 ~ 60%
        self.images = self.images[:int(0.6 * len(self.images))]
        self.labels = self.labels[:int(0.6 * len(self.labels))]
    elif mode == 'val':  # 20% = 60% ~ 80%
        self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
        self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
    else:  # elif mode == 'test':  # 20% = 80% ~ 100%
        self.images = self.images[int(0.8 * len(self.images)):]
        self.labels = self.labels[int(0.8 * len(self.labels)):]
  • 数据读入

将数据通过__getitem__()读入,并对数据进行简单的预处理

def __len__(self):
    return len(self.images)  # 返回所选集的样本数量

def __getitem__(self, idx):
    # idx:[0 ~ len(images)]
    # self.images, self.labels
    # img: 'pokemon\\squirtle\\00000205.jpg'
    # label: 0
    img, label = self.images[idx], self.labels[idx]
    tf = transforms.Compose([
        lambda x: Image.open(x).convert('RGB'),  # 将path所指的图片读入
  1. 设置图片形状

将图片形状统一,方便数据的读入和统一操作

# 接上文
transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),  
# 放大1.25倍是避免之后旋转裁剪操作导致图片出现黑边
  1. 类型转换

将处理好的数据转换成Tensor数据类型,方便后期使用pytorch框架进行训练

# 接上文
transforms.ToTensor(),
  1. 数据扩充

通过对原有图片的旋转与裁剪,来达到增加数据集规模的效果

# 接上文
transforms.RandomRotation(15),  # 将图片随机旋转15°(旋转角度不能过大)
transforms.CenterCrop(self.resize),  # 将旋转后的图片中心裁剪
  1. 数据归一化

通过x = (x - mean) / std操作将图片像素值控制在0~1之间,让训练更加稳定,loss收敛更快

# 接上文
transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
# mean和std的数据来自ImageNet的统计数据,对于图像分类问题具有较好的普适性

外部链接

  1. Pokemon数据集:

链接: https://pan.baidu.com/s/1V_ZJ7ufjUUFZwD2NHSNMFw

提取码:dsxl

  1. 文章项目链接:

链接:https://pan.baidu.com/s/1qg1tw4SBCJaZSUgynz_jwQ

提取码:indf

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值