PyTorch学习笔记(Dataset制作图片集)

在Pytorch学习的时候,除了直接使用MNIST数据集训练LeNET网络外,还需要制作自己的数据集,因此在学长的帮助下,终于是读懂了torch.utils.data.Dataset的使用。因为师兄叫我直接从代码入手学习python,就没有关注特别基础的东西,不过这一次也算是稍微弄懂了python中有关类的定义与使用。

使用Dataset得到自己的数据集时,主要应用到的函数有__init__,getitem,__len__这三个,其中__init__是用来自定义各种参数的,其中需要使用super()来保证能够使得新类完全继承data.Dataset中的各种属性;__getitem__中,根据__init__中的路径读取并存储图片,以及图片的标签,返回的是图片及标签值;__len__用来返回图片集的长度,方便__getitem__读取。

__init__用于定义读取图片的路径以及标签文件
__getitem__根据__init__提供的路径读取图片,并最终将图片、标签存入内存中

虽然使用代码来制作数据集,但是前期的工作需要对图片的存放路径、图片名称以及文件夹的名称做一些处理。比如我的数据的总的路径为“D:\Anaconda3\data\tiny-imagenet-200”,里面存放的内容如下图,words.txt文档存放的是标签对应的实物的名称。
在这里插入图片描述
其中train文件夹存放的是训练样本,内容如下图,文件夹名称则是对应图片的标签,每个文件夹下面对应的是具体的训练图片。
在这里插入图片描述
val文件下下面对应的是验证数据,验证数据集中存在一个文档“val_annotations.txt”,文档内容是图片名称(第一列),及其对应的标签(第二列)。
在这里插入图片描述
具体代码及注释如下:
1 模块导入

import os  #用于对文件夹的系列操作
from os.path import join
import torch
import torch.utils.data as data
from torchvision.transforms import transforms
from PIL import Image

2 定义图片处理
因为之后想送入AlexNet中,因此图片大小设置为32*32,其中有关transforms函数的使用推荐参考:transforms函数

#定义训练、测试及验证数据的图片预处理
transform = transforms.Compose([   #transform的系列操作,建议参考https://zhuanlan.zhihu.com/p/53367135
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(45),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

#定义后缀,用来对图片的查找与读取
#曾出现过错误ValueError: num_samples should be a positive integer value, but got num_samples=0
#因为后缀出错,读取不到图片
#此处需要用(),不能用[],不能是列表数据
FileNameEnd = ('.jpeg', '.JPEG', '.tif', '.jpg', '.png', '.bmp')

3 初始化网络
3.1 训练数据的初始化

class ImageFolder(data.Dataset):
    def __init__(self, root, subdir='train', transform=None):
        super(ImageFolder,self).__init__()

        self.transform = transform  #定义图片转换的类型
        self.image = []     #定义存储图片的列表

        #首先需要得到训练图片的最终路径,用来读取图片,同时需要得到图片对应的文件夹的名称,最为标签数据
        #因此在制作数据集之前,图片存放路径及各个文件夹的命名需要规范
        train_dir = join(root, 'train')  #注意此处不能使用subdir,因为之后的某些值在test及val中也需要使用

        #获取训练文件夹的路径后,train文件夹下面为各种标签命名的文件夹,读取名称作为标签数据
        #sorted可以用来根据名称对读取后的数据排序,得到列表数据
        self.class_names = sorted(os.listdir(train_dir))

        #然后将class_names排序,变成字典,并将序号值与文件夹名称调换位置,使得文件夹名称变为字典的keys数据,数字类型的序号变为values数据
        self.names2index = {v: k for k, v in enumerate(self.class_names)}

        #以上算是制作标签数据的完成,之后需要根据训练、验证、测试数据来具体分析
        #大致的思路是,获取图片具体路径,并将其与标签一一对应,得到多个数组,存入self.image中制作成列表
        #比如self.image[1]可以检索到第二张图片的路径,以及第二张图片的标签形成的数组
        if subdir == 'train':
            for label in self.class_names:
                # 获取文件夹路径,我的路径为:D:/Anaconda3/data/tiny-imagenet-200/train/n01443537
                #其中n01443537为图片对应的文件夹名称,即为标签
                d = join(root, subdir, label)
                #os.walk的用法,遍历文件夹,获取文件的路径,子文件夹的名称,以及文件的名称
                #其中directory为文件夹的初始路径,_表示子文件夹名称,names则是文件名称
                #需要根据具体情况进行修改
                for directory, _, names in os.walk(d):
                    for name in names:
                        filename = join(directory, name)
                        if filename.endswith(FileNameEnd):
                            # 注意此处的双括号,append()可以把数据加到列表后,此处需要的是把数组加进去,因此有append(())
                            self.image.append((filename, self.names2index[label]))

3.2 验证数据及测试数据的初始化

        #验证数据
        #验证数据中的标签数据并不是文件夹名称,存放在txt文档中,因此需要读取txt文档
        if subdir == 'val':
            val_dir = join(root, subdir)
            with open(join(val_dir, 'val_annotations.txt'), 'r') as f:
                infos = f.read().strip().split('\n')[:5000]
                infos = [info.strip().split('\t')[:2] for info in infos]

                self.image = [(join(val_dir, 'images', info[0]), self.names2index[info[1]]) for info in infos]

        #测试数据的读取,测试数据仍然读取的是val文件夹下面的图片,因此test文件下的图片没有被使用
        if subdir == 'test':
            test_dir = join(root, 'val')
            with open(join(test_dir, 'val_annotations.txt'), 'r') as f:
                infos = f.read().strip().split('\n')[5000:]
                infos = [info.strip().split('\t')[:2] for info in infos]

                self.image = [(join(test_dir, 'images', info[0]), self.names2index[info[1]]) for info in infos]

3.3 __getitem__获取图片及标签

    def __getitem__(self, item):
        path, label = self.image[item]
        with open(path, 'rb') as f:    #rb读取二进制文件
            img = Image.open(f).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

3.4 __len__返回图片集长度

    def __len__(self):
        return len(self.image)

4 测试验证集中标签是否与训练数据的图片标签对应
从此部分的代码中我懂了__init__中self的意义,我推测加了self的参数存入到了定义的类的属性中,可以在使用类的时候,用类名称将此参数调用出来,比如此部分代码中的TestData.names2index

#测试下验证数据集中的标签能否与训练数据集中的标签对应
if __name__ == '__main__':
    TestData = ImageFolder('D:/Anaconda3/data/tiny-imagenet-200', subdir='train', transform=transform)

    with open('D:/Anaconda3/data/tiny-imagenet-200/val/val_annotations.txt', 'r') as f:
        infos = f.read().strip().split('\n')
        infos = [info.strip().split('\t')[1] for info in infos]

        for classname in infos:
            if not (classname in TestData.names2index):
                print('Sorry!!!')
        print('Yes!!!')
  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值