PyTorch制作图片数据集

PyTorch制作图片数据集

图片的处理

os.listdir()函数

>>> path = './data/training_set/training_set/cats'
>>> os.listdir(path)
['cat.1.jpg', 'cat.10.jpg', 'cat.100.jpg', 'cat.1000.jpg', 'cat.1001.jpg', 'cat.1002.jpg', 'cat.1003.jpg', 'cat.1004.jpg', 'cat.1005.jpg', 'cat.1006.jpg', 'cat.1007.jpg', 'cat.1008.jpg', 'cat.1009.jpg', 'cat.101.jpg', 'cat.1010.jpg', 'cat.1011.jpg', 'cat.1012.jpg', 'cat.1013.jpg', 'cat.1014.jpg', 'cat.1015.jpg', '
cat.1016.jpg', 'cat.1017.jpg', 'cat.1018.jpg', 'cat.1019.jpg', 'cat.102.jpg', 'cat.1020.jpg', 'cat.1021.jpg', 'cat.1022.jpg', 'cat.1023.jpg', 'cat.1024.jpg', 'cat.1025.jpg', 'cat.1026.jpg', 'cat.1027.jpg', 'cat.1028.jpg', 'cat.1029.jpg', 'cat.103.jpg', 'cat.1030.jpg', 'cat.1031.jpg', 'cat.1032.jpg', 'cat.1033.jpg'
, 'cat.1034.jpg', 'cat.1035.jpg', 'cat.1036.jpg', 'cat.1037.jpg', 'cat.1038.jpg', 'cat.1039.jpg', 'cat.104.jpg', 'cat.1040.jpg', 'cat.1041.jpg', 'cat.1042.jpg', 'cat.1043.jpg', 'cat.1044.jpg', 'cat.1045.jpg', 'cat.1046.jpg', 'cat.1047.jpg', 'cat.1048.jpg', 'cat.1049.jpg', 'cat.105.jpg', 'cat.1050.jpg', 'cat.1051.j
pg', 'cat.1052.jpg', 'cat.1053.jpg', 'cat.1054.jpg', 'cat.1055.jpg', 'cat.1056.jpg', 'cat.1057.jpg', 'cat.1058.jpg', 'cat.1059.jpg', 'cat.106.jpg', 'cat.1060.jpg', 'cat.1061.jpg', 'cat.1062.jpg', 'cat.1063.jpg', 'cat.1064.jpg', 'cat.1065.jpg', 'cat.1066.jpg', 'cat.1067.jpg', 'cat.1068.jpg', 'cat.1069.jpg', 'cat.10
7.jpg', 'cat.1070.jpg', 'cat.1071.jpg', 'cat.1072.jpg', 'cat.1073.jpg', 'cat.1074.jpg', 'cat.1075.jpg', 'cat.1076.jpg', 'cat.1077.jpg', 'cat.1078.jpg', 'cat.1079.jpg', 'cat.108.jpg', 'cat.1080.jpg', 'cat.1081.jpg', 'cat.1082.jpg', 'cat.1083.jpg', 'cat.1084.jpg', 'cat.1085.jpg', 'cat.1086.jpg', 'cat.1087.jpg', 'cat
.1088.jpg', 'cat.1089.jpg', 'cat.109.jpg', 'cat.1090.jpg', 'cat.1091.jpg', 'cat.1092.jpg', 'cat.1093.jpg', 'cat.1094.jpg', 'cat.1095.jpg', 'cat.1096.jpg', 'cat.1097.jpg', 'cat.1098.jpg', 'cat.1099.jpg', 'cat.11.jpg', 'cat.110.jpg', 'cat.1100.jpg', 'cat.1101.jpg', 'cat.1102.jpg', 'cat.1103.jpg', 'cat.1104.jpg', '...]

可以看出作用是返回path路径文件夹下所有文件名的列表

torchvision的transform模块

>>> import torch
>>> from PIL import Image
>>> from torchvision import transforms
>>> img = Image.open('./data/training_set/training_set/cats/cat.1.jpg')
>>> img.size
(300, 280)
>>> transforms.Resize(256)(img).size # 比例改变大小
(274, 256)
>>> transforms.Resize([256,256])(img).size # 结果为图-比例缩小
(256, 256)
>>> transforms.RandomResizedCrop(256)(img).size # 随机切为目标大小
(256, 256)
>>> transforms.RandomSizedCrop([256,200])(img).size # 随机切2
(200, 256)
>>> transforms.Pad(20,1)(img).size # 第一个参数为填充大小,第二个参数为填充值
(340, 320)
>>> transforms.CenterCrop(256)(img).size # 中心切
(256, 256)

原图:
在这里插入图片描述
比例缩小:
在这里插入图片描述

随机切:
在这里插入图片描述
随机切2:
在这里插入图片描述
中心切:
在这里插入图片描述

填充:
在这里插入图片描述

>>> transforms.ToTensor()(img).shape
torch.Size([3, 280, 300])

transforms.ToTensor()函数可以将PIL.Image对象转换为tensor对象
维度由[H, W, C]转为[C, H, W]
H:Height
W:Width
C:Channel
上面的函数一般在ToTensor函数之前进行,下面的函数对tensor进行操作,一般在ToTensor之后进行

>>> imgt = transforms.ToTensor()(img)
>>> imgt.shape
torch.Size([3, 280, 300])
>>> transforms.ToPILImage()(imgt).size # tensor转化为PILImage
(300, 280)
>>> m = imgt.mean(axis = [1,2]) # 求均值,axis=[1,2]可以看作关闭了后两个维度,即在后两个维度上求均值
>>> m
tensor([0.3089, 0.2677, 0.2665])
>>> s = imgt.std(axis = [1,2]) # 求标准差
>>> s
tensor([0.1619, 0.1375, 0.1380])
>>> imgtn = transforms.Normalize(m,s)(imgt) # 对tensor标准化
>>> imgtn.shape
torch.Size([3, 280, 300])
>>> imgtnt = torch.zeros(3,280,300)
>>> imgtnt[0] = (imgt[0]-m[0])/s[0]
>>> imgtnt[1] = (imgt[1]-m[1])/s[1]
>>> imgtnt[2] = (imgt[2]-m[2])/s[2]
>>> imgtnt == imgtn # 标准化过程就是对每个通道output = (input-mean)/std
tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

多种变化的组合transforms.Compose()
data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(), # 以0.5的概率随机水平翻转
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

Dataset类的重写

Dataset类源码

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
 
    def __getitem__(self, index):
        raise NotImplementedError
 
    def __len__(self):
        raise NotImplementedError
 
    def __add__(self, other):
        return ConcatDataset([self, other])

由上可以看出__getitem__和__len__方法必须重写
__len__方法用于使用 len(Dataset) 函数
__getitem__方法用于 Dataset[n] 操作
模板可以参考

from torch.utils.data  import Dataset
import pandas as pd  # pandas库提供了读取csv文件的函数read_csv()

class myDataset(Dataset):  # 定义自己的数据类myDataset,继承的抽象类Dataset
    def __init__(self, csv_file, txt_file,root_dir,other_file):  # csv_file:抽象的表示.csv文件;txt_file:抽象的表示txt文件;
                                                                 # root_dir:地址,这些参数放在初始化函数里
        self.csv_data= pd.read_csv(csv_file)   # 读取csv文件,并且赋给他本身
        with open(txt_file,'r') as f:  # 读取txt文件,并且赋给他本身,读取的方式为:with open(...) as f:
            data_list = f.readlines()   # 读取每一行数据,并且放到data_list里
        self.txt_data = data_list
        self.root_dir = root_dir

#  实现下面这个方法:                                                                                                                     
    
    def __len__(self):   # 定义自己的数据类,必须重写这个方法(函数)
        return len(self.csv_data)  # 返回的数据的长度
    
    def __getitem__(self, idx):  # 定义自己的数据类,必须重写这个方法(函数)
        data = (self.csv_data[idx],self.txt_data[idx])   # 获取数据的方式,按照索引进行的 
        return data                                                                     
数据的批量读取

torch.utils.data已经提供的类:Dataset,但是通过这种方式只能一个个的数据的把数据全部读出来,定义了数据读取的方式,不能实现** 批量**的把数据读取出来,为此pytorch有提供了一个方法:DataLoader(),它的参数如下:

from torch.utils.data import DataLoader

dataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=default_collate)

myDatase:上面自己定义的数据类
batch_size=32:实现批量读取数据,比如一次取32个数据
shuffle=True:将顺序打乱
collate_fn:表示的是如何读取样本,可以自己定义函数来准确的说明想要实现的功能。
该段来源于该博客

以下是kaggle平台上猫狗识别数据集的制作

import numpy as np
from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import os


class catdogDataset(Dataset):
    def __init__(self, file_list, dir, mode='train', transform=None):
        self.file_list = file_list
        self.dir = dir
        self.mode = mode
        self.transform = transform
        if 'dog' in self.file_list[0]:
            self.label = 1  # 标签为1时为狗,标签为0是为猫
        else:
            self.label = 0

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.dir, self.file_list[idx]))
        if self.transform:
            img = self.transform(img)
        if self.mode == 'train':
            img = np.array(img)
            return img.astype('float32'), self.label  # 此处当模式为训练时,返回图片的32位浮点array和它的标签
        else:
            img = np.array(img)
            return img.astype('float32'), self.file_list[idx]


datapathtra = './data/training_set/training_set'
cat_file_list = os.listdir(os.path.join(datapathtra, 'cats'))
dog_file_list = os.listdir(os.path.join(datapathtra, 'dogs'))

data_transform = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

catDataset = catdogDataset(cat_file_list, os.path.join(datapathtra, 'cats'), transform=data_transform)
dogDataset = catdogDataset(dog_file_list, os.path.join(datapathtra, 'dogs'), transform=data_transform)

totalDataset = ConcatDataset([catDataset, dogDataset])  # 聚合两个数据集
print(totalDataset[5])  # 此处为__getitem__的用法
print(len(totalDataset))  # 此处为__len__的用法

  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值