PyTorch自制数据集

PyTorch加载数据主要分为两类:只有图片的数据集以及含有csv保存标签的数据集。只有图片的数据集又分为两类:标签在文件夹上和标签在图片名上。

学习地址

1.标签在文件夹上

在这里插入图片描述
此情况下导入数据集,只需要调用PyTorch中的ImageFolder进行载入。(可以直接采用split_data.py划分训练集、测试集、验证集)

导入所需的库

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
#以上语句是由于python与torch版本不匹配才加的与加载数据无关

from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torch
from torchvision import transforms, utils,datasets
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

数据增强函数

data_transform = transforms.Compose([
 transforms.Resize(32), # 缩放图片(Image),保持长宽比不变,最短边为32像素
 transforms.CenterCrop(32), # 从图片中间切出32*32的图片
 transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 transforms.Normalize(mean=[0.492, 0.461, 0.417], std=[0.256, 0.248, 0.251]) # 标准化至[-1, 1],规定均值和标准差
])
  • data_transform作用是对图片进行标准化和归一化 Resize(32)缩放图片(Image),保持长宽比不变,最短边为32像素 CenterCrop(32)从图片中间切出32*32的图片 RandomSizedCrop(32)这一句的作用是对原图进行随机大小和高宽比的裁剪,最后的尺寸为32x32
  • RandomHorizontalFlip()这个则是对原图像根据概率进行随机水平翻转
  • transforms.ToTensor()将图片转化为张量,并使图片的形式表现为通道x高x宽的形式
  • transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])这个则是对数据 进行正则化操作,第一个参数为均值,第二个参数为标准差。

计算RGB图片的均值标准差代码如下:

import numpy as np
import cv2
import os

#以下代码用于求RGB图像的均值和标准差
# img_h, img_w = 32, 32
img_h, img_w = 32, 32  # 经过处理后你的图片的尺寸大小
means, stdevs = [], []
img_list = []

imgs_path = "/home/xyjin/PycharmProjects/data_mining/data_test/study/dogs_cats/dogs_cats/data/train/cats"  # 数据集的路径采用绝对引用

imgs_path_list = os.listdir(imgs_path)

len_ = len(imgs_path_list)
i = 0
for item in imgs_path_list:
    img = cv2.imread(os.path.join(imgs_path, item))
    img = cv2.resize(img, (img_w, img_h))
    img = img[:, :, :, np.newaxis]
    img_list.append(img)
    i += 1
    print(i, '/', len_)

imgs = np.concatenate(img_list, axis=3)
imgs = imgs.astype(np.float32) / 255.

for i in range(3):
    pixels = imgs[:, :, i, :].ravel()  # 拉成一行
    means.append(np.mean(pixels))
    stdevs.append(np.std(pixels))

# BGR --> RGB , CV读取的需要转换,PIL读取的不用转换
means.reverse()
stdevs.reverse()

print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))

计算结果如下:
在这里插入图片描述

调用ImageFolder进行数据集的加载

hymenoptera_dataset = datasets.ImageFolder(root="/home/xyjin/PycharmProjects/data_mining/data_test/study/dogs_cats/dogs_cats/data/train",
           transform=data_transform) #导入数据集
  • 第一个参数为数据集路径的参数。
  • 第二个参数为数据增强函数的调用,对加载的数据集进行相关数据操作。

导入数据集后,查看导入情况。

  • 查看图像相关信息
img, label = hymenoptera_dataset[15000] #将启动魔法方法__getitem__(0)
"""这个15000,表示所有文件夹排序后的第15001张图片,0是第一张图片"""
print(label)   #查看标签
"""这里的0表示cat,1表示dog;因为是按文件夹排列的顺序,如果有第三个文件夹pig则2表示pig"""
print(img.size())
print(img)

#处理后的图片信息
for img, label in hymenoptera_dataset:
 print("图像img的形状{},标签label的值{}".format(img.shape, label))
 print("图像数据预处理后:\n",img)
 break

结果:

/home/xyjin/anaconda3/envs/pytorch/bin/python /home/xyjin/PycharmProjects/data_mining/data_test/temp.py
1
torch.Size([3, 32, 32])
tensor([[[-0.1909, -0.2675,  0.0695,  ...,  0.4985,  0.4678,  0.2687],
         [-0.1296, -0.2521,  0.0542,  ...,  0.4525,  0.4066,  0.2381],
         [-0.0530, -0.2368,  0.0389,  ...,  0.4525,  0.4219,  0.2381],
         ...,
         [ 0.0542,  0.1155,  0.2074,  ...,  0.3300,  0.3606,  0.3453],
         [ 0.2534,  0.2840,  0.2840,  ...,  0.2074,  0.3146,  0.3759],
         [ 0.3146,  0.3606,  0.3300,  ...,  0.2534,  0.1615,  0.3453]],

        [[ 0.0070, -0.1353,  0.1810,  ...,  0.8135,  0.7819,  0.3233],
         [ 0.0545, -0.1195,  0.1652,  ...,  0.7661,  0.7502,  0.3075],
         [ 0.1652, -0.1195,  0.1652,  ...,  0.7502,  0.6237,  0.2759],
         ...,
         [ 0.2600,  0.3075,  0.3707,  ...,  0.4498,  0.4972,  0.4972],
         [ 0.4182,  0.4498,  0.4498,  ...,  0.3391,  0.4656,  0.5447],
         [ 0.4972,  0.5289,  0.4972,  ...,  0.3865,  0.2917,  0.4972]],

        [[ 0.0260, -0.0990,  0.1823,  ...,  1.0259,  0.9322,  0.1823],
         [ 0.0885, -0.0834,  0.1510,  ...,  0.9791,  0.8697,  0.1666],
         [ 0.2291, -0.0677,  0.1979,  ...,  0.9009,  0.7291,  0.1198],
         ...,
         [ 0.3541,  0.4010,  0.4635,  ...,  0.4479,  0.4791,  0.4947],
         [ 0.5260,  0.5728,  0.5416,  ...,  0.3229,  0.4635,  0.5260],
         [ 0.5728,  0.6197,  0.6041,  ...,  0.3854,  0.2760,  0.5260]]])
图像img的形状torch.Size([3, 32, 32]),标签label的值0
图像数据预处理后:
 tensor([[[ 1.8159,  1.8618,  1.8925,  ...,  1.9384,  1.9231,  1.9078],
         [ 1.8006,  1.8465,  1.8771,  ...,  1.9384,  1.9231,  1.9231],
         [ 1.7546,  1.8006,  1.8618,  ...,  1.9384,  1.9384,  1.9384],
         ...,
         [-0.8036, -0.7270, -0.7270,  ..., -1.4930, -1.5389, -1.6002],
         [-0.8496, -0.7883, -0.7883,  ..., -0.6351, -1.2172, -1.6155],
         [-0.8343, -0.8496, -0.8343,  ..., -1.1253, -1.3857, -1.6768]],

        [[ 1.3195,  1.3669,  1.4618,  ...,  1.7148,  1.7464,  1.7781],
         [ 1.3037,  1.3511,  1.4144,  ...,  1.7148,  1.7306,  1.7781],
         [ 1.2562,  1.3037,  1.3669,  ...,  1.7148,  1.7464,  1.7939],
         ...,
         [-1.0366, -1.0208, -1.0050,  ..., -1.4952, -1.5268, -1.5901],
         [-1.0524, -1.0524, -1.0524,  ..., -0.8469, -1.3054, -1.5901],
         [-1.0208, -1.0840, -1.0999,  ..., -1.2264, -1.4161, -1.6375]],

        [[ 0.2447,  0.2916,  0.3697,  ...,  0.7135,  0.7291,  0.7760],
         [ 0.2291,  0.2760,  0.3229,  ...,  0.7135,  0.7291,  0.8072],
         [ 0.1823,  0.2291,  0.2760,  ...,  0.6978,  0.7447,  0.8541],
         ...,
         [-1.5520, -1.5051, -1.5051,  ..., -1.5989, -1.5989, -1.6145],
         [-1.5676, -1.5364, -1.5207,  ..., -1.4114, -1.5520, -1.6145],
         [-1.5676, -1.5676, -1.5520,  ..., -1.5051, -1.5832, -1.6145]]])

Process finished with exit code 0

对加载的数据进行batch size处理:表示加载的数据切分为4个为一组的 shuffle=True,送入训练的的数据是打乱后的,而不是顺序输入。

dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4,shuffle=True)
  • 使用显示图片与对应标签的方式进行查验
import torchvision
import matplotlib.pyplot as plt
import numpy as np
# %matplotlib inline
# 显示图像
def imshow(img):
 img = img / 2 + 0.5  # unnormalize
 npimg = img.numpy()
 plt.imshow(np.transpose(npimg, (1, 2, 0)))
 plt.show()
# 随机获取部分训练数据
dataiter = iter(dataset_loader)#此处填写加载的数据集
images, labels = dataiter.next()
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join('%s' % ["小狗" if labels[j].item()==1 else "小猫" for j in range(4)]))

图片结果:
在这里插入图片描述
[ ’ 小 狗 ’ , ’ 小 狗 ’ , ’ 小 猫 ’ , ’ 小 狗 ’ ]

总结:标签在文件夹上的数据载入,代码如下:

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import torch
from torchvision import transforms, utils,datasets
from PIL import Image
import pandas as pd
import numpy as np
#过滤警告信息
import warnings
warnings.filterwarnings("ignore")

data_transform = transforms.Compose([
 transforms.Resize(32), # 缩放图片(Image),保持长宽比不变,最短边为32像素
 transforms.CenterCrop(32), # 从图片中间切出32*32的图片
 transforms.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
 transforms.Normalize(mean=[0.492, 0.461, 0.417], std=[0.256, 0.248, 0.251]) # 标准化至[-1, 1],规定均值和标准差
])

hymenoptera_dataset = hymenoptera_dataset = datasets.ImageFolder(root="/home/xyjin/PycharmProjects/data_mining/data_test/study/dogs_cats/dogs_cats/data/train",
           transform=data_transform) #导入数据集
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,batch_size=4,shuffle=True) #

2. 标签在图片名上

3. 将数据集分为训练集,验证集和测试集

4.标签存储在csv文件中

数据集介绍

  • 本数据集加载使用的数据集是给狗进行种类的识别的数据集
    在这里插入图片描述
    该文件夹中有:测试集,训练集,还有存有训练集标签和对应标签的文件名的csv文件
    在这里插入图片描述

数据集加载的方式和第一种相同,需要将每一个种类的数据照片放到对应种类命名的文件夹中

首先进行数据集的拆分,导入对应的库;并设置用到的一些变量

import math
import os
import shutil
from collections import Counter

data_dir = "/home/xyjin/PycharmProjects/data_mining/data_test/study/dog-breed-identification" #数据集的根目录
label_file = 'labels.csv'#根目录中csv的文件名加后缀
train_dir = 'train'#根目录中的训练集文件夹的名字
test_dir = 'test'#根目录中的测试集文件夹的名字
input_dir = 'train_valid_test'#用于存放拆分数据集的文件夹的名字,可以不用先创建,会自动创建
batch_size = 4#送往训练的一批次中的数据集的个数
valid_ratio = 0.1#将训练集拆分为90%为训练集10%为验证集
  • 训练集拆分为训练集(test)和验证集(valid)然后分别放到对应的文件夹中
    在这里插入图片描述

  • train_valid文件夹中包含了所有训练集、验证集的数据
    在这里插入图片描述

  • 训练集、验证集和train_valid的文件夹中都是一个个种类的小文件夹,其中存放着对应的数据集图像

  • test中没有标签所以所有的数据照片都存放在unknown中的

程序如下:

def reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,
                   valid_ratio):
    # 读取训练数据标签,label.csv文件读取标签以及对应的文件名
    with open(os.path.join(data_dir, label_file), 'r') as f:
        # 跳过文件头行(栏名称)
        lines = f.readlines()[1:]
        tokens = [l.rstrip().split(',') for l in lines]
        idx_label = dict(((idx, label) for idx, label in tokens))
    labels = set(idx_label.values())

    num_train = len(os.listdir(os.path.join(data_dir, train_dir)))#获取训练集的数量便于数据集的分割
    # 训练集中数量最少一类的狗的数量
    min_num_train_per_label = (
        Counter(idx_label.values()).most_common()[:-2:-1][0][1])
    # 验证集中每类狗的数量
    num_valid_per_label = math.floor(min_num_train_per_label * valid_ratio)
    label_count = dict()

    def mkdir_if_not_exist(path):#判断是否有存放拆分后数据集的文件夹,没有就创建一个
        if not os.path.exists(os.path.join(*path)):
            os.makedirs(os.path.join(*path))

    # 整理训练和验证集,将数据集进行拆分复制到预先设置好的存放文件夹中。
    for train_file in os.listdir(os.path.join(data_dir, train_dir)):
        idx = train_file.split('.')[0]
        label = idx_label[idx]
        mkdir_if_not_exist([data_dir, input_dir, 'train_valid', label])
        shutil.copy(os.path.join(data_dir, train_dir, train_file),
                    os.path.join(data_dir, input_dir, 'train_valid', label))
        if label not in label_count or label_count[label] < num_valid_per_label:
            mkdir_if_not_exist([data_dir, input_dir, 'valid', label])
            shutil.copy(os.path.join(data_dir, train_dir, train_file),
                        os.path.join(data_dir, input_dir, 'valid', label))
            label_count[label] = label_count.get(label, 0) + 1
        else:
            mkdir_if_not_exist([data_dir, input_dir, 'train', label])
            shutil.copy(os.path.join(data_dir, train_dir, train_file),
                        os.path.join(data_dir, input_dir, 'train', label))

    # 整理测试集,将测试集复制存放在新建路径下的unknown文件夹中
    mkdir_if_not_exist([data_dir, input_dir, 'test', 'unknown'])
    for test_file in os.listdir(os.path.join(data_dir, test_dir)):
        shutil.copy(os.path.join(data_dir, test_dir, test_file),
                    os.path.join(data_dir, input_dir, 'test', 'unknown'))

reorg_dog_data(data_dir, label_file, train_dir, test_dir, input_dir,valid_ratio)
  • 8
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值