Pytroch数据集处理以及自定义数据集

为了加强对Pytroch框架的掌握,本文梳理了一下Pytroch自定义数据集常用的知识。首先便是如何从文件夹中读取图片名称或者从txt文档里读取图片名称方便导入图片。接着便是常见的torchvision.transforms的使用以及如何自定义transform。最后便是基于torch.utils.data.Dataset构建自定义数据集,并使用dataloader导入。这里参考了许多网上博客,属于汇总。

一、常见读取文件操作

1.python读取文件夹中的所有图片并将图片名逐行写入txt中

import os

img_path = r'' # 这里需要写入要读取的文件夹路径,这里的r是为了防止路径转义,\\同样也是为了防止\转义
save_txt_path = r'' # 这里需要写入读取后的txt文件要保存的路径和名称

imgs = os.listdir(img_path) # 这里读取文件中图片名称并且将其用列表存储。os.listdir的返回值是一个列表,列表里面存储该path下面的子目录的名称

txt = open(save_txt_path, 'w')

for img in imgs:
    txt.write(img + '\n') # # 逐行写入图片名,'\n'表示换行

txt.close()

这里由于大部分东西都是字符串,故需要用到一些常见的字符串函数

2.python读取txt文本内容

txt_path = r'' 
f = open(txt_path, 't') # 这里的t是一个参数,具体参考下面的图片 

# 第一种read方法,表示一次读取文件全部内容,该方法返回字符串
lines = f.read()
print(lines)
print(type(lines))
f.close()

# 第二种readline方法,该方法每次读出一行内容,该方法返回一个字符串对象
line = f.readline()
while line:
    print(line)
    print(type(line))
    line = f.readline() # 继续读取下一行
f.close()

# 第三种readlines方法,该方法读取整个文件所有行,保存在一个列表(list)变量中,每次读取一行
lines = f.readline()
for line in lines:
    print(line)
    print(type(line))
f.close()

在这里插入图片描述

二、transforms

1.torchvision.transforms

主要用来进行data augmentation操作,是Pytroch中的图像预处理包。

torchvision是pytroch的一个图形库,主要用来构建计算机视觉模型。

  • torchvision.datasets:一些加载数据的函数及常用的数据接口
  • torchvision.models:包含常用的模型结构(含预训练模型),例如ResNet等
  • torchvision.utils:一些常见的辅助工具代码
  • torchvision.transforms:常用的图像处理
transform.Compose() # 把几个常用的变化放到一起
transforms.Compose([
     transforms.CenterCrop(10),
     transforms.ToTensor(),])

transform.Compose这个类会将列表里面的transform操作进行遍历。

其类实现比较简单

class Compose:

    def __init__(self, transforms):
            self.transforms = transforms
	
    # 类实例化之后直接调用
	def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

好了,有一些常见的图像变化,如resize和标准化(transforms.Resize、transforms.Normalize)等,这里就不一一列举了。需要的可以查阅这篇博客

2.自定义transform

首先,如果自定义transform就需要遵循一定的规则。通过查看Compose这个类的实现,我们可以发现自定义transform需要两个约束:

  • 仅接受一个参数,并返回一个参数。如果是多个图片需要同时处理,可以用字典传输
  • 实现需要在 __call__ 中进行

如下为参考代码

from PIL import Image
from torchvision import transforms
from utils import transform_invert
import random
import numpy as np

class Enhance(object):
    """增加椒盐噪声
    Args:
        x():乘
        y (): 加
    """

    def __init__(self, x=1, y=0):
        self.x = x
        self.y = y

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL image.
        """

        img_ = np.array(img).copy()
        img_ = img_*self.x + self.y
        return Image.fromarray(img_.astype('uint8')).convert('RGB')


if __name__ == '__main__':
    # 1.读取图像
    img = Image.open(r"./cat.png").convert('RGB')


    # 2.确定预处理方式
    img_transform = transforms.Compose([
                        transforms.Grayscale(),
                        Enhance(x=2, y=22),
                        transforms.ToTensor()  # 转Tensor型变量
                                        ])

    img_tensor = img_transform(img)

    # 4.逆Transform变换
    img = transform_invert(img_tensor, img_transform)  # input: shape=[c h w]
    # 5.进行预处理效果展示
    img.show()

多个图片需要处理参考代码如下:

from PIL import Image
import random

class RandomFlipOrRotate(object):
    def __call__(self, sample):
        img1, img2, mask1, mask2, mask_bin = \
            sample['img1'], sample['img2'], sample['mask1'], sample['mask2'], sample['mask_bin']

        rand = random.random()
        if rand < 1 / 6:
            img1 = img1.transpose(Image.FLIP_LEFT_RIGHT)
            mask1 = mask1.transpose(Image.FLIP_LEFT_RIGHT)
            img2 = img2.transpose(Image.FLIP_LEFT_RIGHT)
            mask2 = mask2.transpose(Image.FLIP_LEFT_RIGHT)
            mask_bin = mask_bin.transpose(Image.FLIP_LEFT_RIGHT)
        return {'img1': img1, 'img2': img2, 'mask1': mask1, 'mask2': mask2, 'mask_bin': mask_bin}

# 这里只展示了部分代码
transform = transforms.Compose([
            tr.RandomFlipOrRotate()])
sample = transform({'img1': img1, 'img2': img2, 'mask1': mask1, 'mask2': mask2, 'mask_bin': mask_bin})

三、自定义数据集

1.torch.utils.data.Dataset

同样,如果自定义数据集也必须满足一定的条件:

  • 需要继承data.Dataset

  • __getitem__()__len__()两个方法是必须重写的。__getitem__()输入索引,根据索引返回训练数据,如图片和label,而__len__()返回数据长度。

    class CustomDataset(data.Dataset):#需要继承data.Dataset
        def __init__(self):
            # TODO
            # 1. Initialize file path or list of file names.
            pass
        def __getitem__(self, index):
            # TODO
            # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
            # 2. Preprocess the data (e.g. torchvision.Transform).
            # 3. Return a data pair (e.g. image and label).
            #这里需要注意的是,第一步:read one data,是一个data
            pass
        def __len__(self):
            # You should change 0 to the total size of your dataset.
            return 0
    

2.实例

这是一个实现语义变化检测数据集处理的代码,如果有需要可以私信或评论。因只需要看看如何使用即可,就没有详细介绍。

import datasets.transform as tr

import numpy as np
import os
from PIL import Image
import random
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class ChangeDetection(Dataset):

    def __init__(self, root, mode, use_pseudo_label=False):
        super(ChangeDetection, self).__init__()
        self.root = root

        self.mode = mode
        self.use_pseudo_label = use_pseudo_label

        if mode in ['train', 'val', 'pseudo_labeling']:
            self.root = os.path.join(self.root, 'train')
            self.ids = os.listdir(os.path.join(self.root, "im1"))
            self.ids.sort()
            if mode == 'val':
                self.ids = self.ids[::10]
            else:
                self.ids = list(set(self.ids) - set(self.ids[::10]))
        else:
            self.root = os.path.join(self.root, 'val')
            self.ids = os.listdir(os.path.join(self.root, 'im1'))
        self.ids.sort()

        self.transform = transforms.Compose([
            tr.RandomFlipOrRotate()
        ])

        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

    def __getitem__(self, index):
        id = self.ids[index]

        img1 = Image.open(os.path.join(self.root, 'im1', id))
        img2 = Image.open(os.path.join(self.root, 'im2', id))

        if self.mode == "test":
            img1 = self.normalize(img1)
            img2 = self.normalize(img2)
            return img1, img2, id

        if self.mode == "val":
            mask1 = Image.open(os.path.join(self.root, 'label1', id))
            mask2 = Image.open(os.path.join(self.root, 'label2', id))
        else:
            if self.mode == 'pseudo_labeling' or (self.mode == 'train' and not self.use_pseudo_label):
                mask1 = Image.open(os.path.join(self.root, 'label1', id))
                mask2 = Image.open(os.path.join(self.root, 'label2', id))
            else:
                mask1 = Image.open(os.path.join('outdir/masks/train/im1', id))
                mask2 = Image.open(os.path.join('outdir/masks/train/im2', id))

            if self.mode == 'train':
                gt_mask1 = np.array(Image.open(os.path.join(self.root, 'label1', id)))
                mask_bin = np.zeros_like(gt_mask1)
                mask_bin[gt_mask1 == 0] = 1
                mask_bin = Image.fromarray(mask_bin)

                sample = self.transform({'img1': img1, 'img2': img2, 'mask1': mask1, 'mask2': mask2,
                                         'mask_bin': mask_bin})
                img1, img2, mask1, mask2, mask_bin = sample['img1'], sample['img2'], sample['mask1'], \
                                                     sample['mask2'], sample['mask_bin']

        img1 = self.normalize(img1)
        img2 = self.normalize(img2)
        mask1 = torch.from_numpy(np.array(mask1)).long()
        mask2 = torch.from_numpy(np.array(mask2)).long()
        if self.mode == 'train':
            mask_bin = torch.from_numpy(np.array(mask_bin)).float()
            return img1, img2, mask1, mask2, mask_bin

        return img1, img2, mask1, mask2, id

    def __len__(self):
        return len(self.ids)

3.torch.utils.data.DataLoader

下面显示了 PyTorch 库中DataLoader函数的语法及其参数信息。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

在这里插入图片描述

4.python类常见的魔法方法

这个部分有待补充

四、实战

有了上述基础,我们便可以用Pytroch自定义自己的数据集并使用DataLoader载入了。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值