pytorch快速上手(3)-----如何处理自己的数据集

转载自博客:https://blog.csdn.net/Teeyohuang/article/details/79587125

之前讲的例子,程序都是调用的datasets方法,下载的torchvision本身就提供的数据,那么如果想导入自己的数据应该怎么办呢?

本篇就讲解一下如何创建自己的数据集。

1.用于分类的数据集

以mnist数据集为例

这里的mnist数据集并不是torchvision里面的,而是我自己的以图片格式保存的数据集,因为我在测试STN时,希望自己再把这些手写体做一些形变,

所以就先把MNIST数据集转化成了jpg图片格式,然后做了一些形变,当然这不是重点。首先我们看一下我的数据集的情况:
在这里插入图片描述
如图所示,我的图片数据集确实是jpg图片
再看我的存储图片名和label信息的文本:
在这里插入图片描述
如图所示,我的mnist.txt文本每一行分为两部分,第一部分是具体路径+图片名.jpg

第二部分就是label信息,因为前面这部分图片都是0 ,所以他们的分类的label信息就是0

要创建你自己的 用于分类的 数据集,也要包含上述两个部分,1.图片数据集,2.文本信息(这个txt文件可以用python或者C++轻易创建,再此不详述)

主要代码

from PIL import Image
import torch
 
class MyDataset(torch.utils.data.Dataset): #创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
    def __init__(self,root, datatxt, transform=None, target_transform=None): #初始化一些需要传入的参数
        super(MyDataset,self).__init__()
        fh = open(root + datatxt, 'r') #按照传入的路径和txt文本参数,打开这个文本,并读取内容
        imgs = []                      #创建一个名为img的空列表,一会儿用来装东西
        for line in fh:                #按行循环txt文本中的内容
            line = line.rstrip()       # 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python
            words = line.split()   #通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
            imgs.append((words[0],int(words[1]))) #把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定
                                        # 很显然,根据我刚才截图所示txt的内容,words[0]是图片信息,words[1]是lable
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
 
    def __getitem__(self, index):    #这个方法是必须要有的,用于按照索引读取每个元素的具体内容
        fn, label = self.imgs[index] #fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
        img = Image.open(root+fn).convert('RGB') #按照path读入图片from PIL import Image # 按照路径读取图片
 
        if self.transform is not None:
            img = self.transform(img) #是否进行transform
        return img,label  #return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
 
    def __len__(self): #这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.imgs)
 
#根据自己定义的那个勒MyDataset来创建数据集!注意是数据集!而不是loader迭代器
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
#然后就是调用DataLoader和刚刚创建的数据集,来创建dataloader,这里提一句,loader的长度是有多少个batch,所以和batch_size有关
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
 

再补充一点代码,以便更好的理解 __getitem__这个方法

for batch_index, data, target in test_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

这段代码是我从测试的部分中截取出来的,为什么直接能用for data, target In test_loader这样的语句呢?

其实这个语句还可以这么写:

for batch_index, batch in train_loader
        data, target = batch

这样就好理解了,因为这个迭代器每一次循环所得的batch里面装的东西,就是我在__getitem__方法最后return回来的

所以你想在训练或者测试的时候还得到其他信息的话,就去增加一些返回值即可,只要是能return出来的,就能在每个batch中读取到!

有朋友可能想问,如果我的label信息不是数字而是图像呢?比如分割任务,它的label就是图像,这样的数据集的建立,请继续往下看~

2.分割任务(标签是图像)

上面是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签,本篇介绍输入的标签label亦为图像的数据集,并包含一些常用的处理手段。比如做图像语义分割时就会用到这种数据输入方式。

1、数据集简介

以VOC2012数据集为例,图像是RGB3通道的,label是1通道的,(其实label原来是几通道的无所谓,只要读取的时候转化成灰度图就行)。
训练数据:
在这里插入图片描述
语义label:
在这里插入图片描述
这里我们看到label图片都是黑色的,只有白色的轮廓而已。

其实是因为label图片里的像素值取值范围是0 ~ 20,即像素点可能的类别共有21类(对此数据集来说),详情如下:

在这里插入图片描述
所以对于灰度值0—20来说,我们肉眼看上去就确实都是黑色的,因为灰度值太低了,而白色的轮廓的灰度值是255!

但是这些边界在计算损失值的时候是不作为有效值的,也就是对于灰度值==255的点是忽略的

如果想看的话,可以用一些色彩变换,对0–20这每一个数字对应一个色彩,就能看出来了,示例如下:
在这里插入图片描述
这不是重点,只是给大家看一下方便理解而已。

2、文本信息

同样有一个文本来指导我对数据的读取,我的信息如下:
在这里插入图片描述
这其实就是一个记载了图像ID的文本文档,连后缀都没有,但我们依然可以根据这个去数据集中读取相应的image和label。

3.示例代码

这个代码是我自己在利用deeplabV2 跑semantic segmentation 任务时写的一个,也许写的并不优美,但反正是可以用的,

可以做个抛砖引玉的目的,对于才入门的朋友,理解这个思路就可,不必照搬我的代码风格……

import os
import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
import cv2
from PIL import Image
import torchvision.transforms as transforms
from torch.utils import data
 
class VOCDataSet(data.Dataset):
    def __init__(self, root, list_path,  crop_size=(321, 321), mean=(104.008, 116.669, 122.675), mirror=True, scale=True, ignore_label=255):
        super(VOCDataSet,self).__init__()
        self.root = root
        self.list_path = list_path
        self.crop_h, self.crop_w = crop_size
        self.ignore_label = ignore_label
        self.mean = np.asarray(mean, np.float32)
        self.is_mirror = mirror
        self.is_scale = scale
 
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
 
        self.files = []
        for name in self.img_ids:
            img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
            label_file = os.path.join(self.root, "SegmentationClassAug/%s.png" % name)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })
 
    def __len__(self):
        return len(self.files)
 
 
    def __getitem__(self, index):
        datafiles = self.files[index]
 
        '''load the datas'''
        name = datafiles["name"]
        image = Image.open(datafiles["img"]).convert('RGB')
        label = Image.open(datafiles["label"]).convert('L')
        size_origin = image.size # W * H
 
        '''random scale the images and labels'''
        if self.is_scale: #如果我在定义dataset时选择了scale=True,就执行本语句对尺度进行随机变换
            ratio = 0.5 + random.randint(0, 11) // 10.0 #0.5~1.5
            out_h, out_w = int(size_origin[1]*ratio), int(size_origin[0]*ratio)
            # (H,W)for Resize
            image = transforms.Resize((out_h, out_w), Image.LANCZOS)(image)
            label = transforms.Resize((out_h, out_w), Image.NEAREST)(label)
 
        '''pad the inputs if their size is smaller than the crop_size'''
        pad_w = max(self.crop_w - out_w, 0)
        pad_h = max(self.crop_h - out_h, 0)
        img_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=0, padding_mode='constant')(image)
        label_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=self.ignore_label, padding_mode='constant')(label)
        out_size = img_pad.size
 
        '''random crop the inputs'''
        if (self.crop_h != 0 or self.crop_w != 0):
            #select a random start-point for croping operation
            h_off = random.randint(0, out_size[1] - self.crop_h)
            w_off = random.randint(0, out_size[0] - self.crop_w)
            #crop the image and the label
            image = img_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))
            label = label_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))
 
        '''mirror operation'''
        if self.is_mirror:
            if np.random.random() < 0.5:
                #0:FLIP_LEFT_RIGHT, 1:FLIP_TOP_BOTTOM, 2:ROTATE_90, 3:ROTATE_180, 4:or ROTATE_270.
                image = image.transpose(0)
                label = label.transpose(0)
 
        '''convert PIL Image to numpy array'''
        I = np.asarray(image,np.float32) - self.mean
        I = I.transpose((2,0,1))#transpose the  H*W*C to C*H*W
        L = np.asarray(np.array(label), np.int64)
        #print(I.shape,L.shape)
        return I.copy(), L.copy(), np.array(size_origin), name
 
#这是一个测试函数,也即我的代码写好后,如果直接python运行当前py文件,就会执行以下代码的内容,以检测我上面的代码是否有问题,这其实就是方便我们调试,而不是每次都去run整个网络再看哪里报错
if __name__ == '__main__':
    DATA_DIRECTORY = '/home/teeyo/STA/Data/voc_aug/'
    DATA_LIST_PATH = '../dataset/list/val.txt'
    Batch_size = 4
    MEAN = (104.008, 116.669, 122.675)
    dst = VOCDataSet(DATA_DIRECTORY,DATA_LIST_PATH, mean=(0,0,0))
    # just for test,  so the mean is (0,0,0) to show the original images.
    # But when we are training a model, the mean should have another value
    trainloader = data.DataLoader(dst, batch_size = Batch_size)
    plt.ion()
    for i, data in enumerate(trainloader):
        imgs, labels,_,_= data
        if i%1 == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = img.astype(np.uint8) #change the dtype from float32 to uint8, because the plt.imshow() need the uint8
            img = np.transpose(img, (1, 2, 0))#transpose the Channels*H*W to  H*W*Channels
            #img = img[:, :, ::-1]
            plt.imshow(img)
            plt.show()
            plt.pause(0.5)
 
            #input()

我个人觉得我应该注释的地方都有相应的注释,虽然有点长, 因为实现了crop和翻转以及scale等功能,但是大家可以下去慢慢揣摩,理解其中的主要思路,与上面模块1分类任务部分做对比,那部分相当于是提供了最基本的骨架,而这下面的内容就在骨架上长肉生发而已,有疑问的欢迎评论探讨~~

其他像多任务模型(分类+检测+ReID)的数据集处理类似,都是在此框架上丰富一下处理的过程,只要return的内容是你想要的就行,对应后面dataloader取出后,前向推理,算loss反传能取对就行,代码实现各异,但思路是一致的,殊途同归~

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值