Pytorch自定义图片数据集

本篇博客旨在实现pytorch读取图片并自定义图片数据集

图像加载方法

主流的图像加载方法主要有三种

下表中xxx表示图片的路径

函数/方法返回值图像像素格式像素值范围图像矩阵表示
skimageio.imread(xxx)numpy.ndarrayRGB[0, 255](H X W X C)
cv2cv2.imread(xxx)numpy.ndarrayBGR[0, 255](H X W X C)
Pillow(PIL)Image.open(xxx)PIL.Image.Image对象根据图像格式,一般为RGB[0, 255]

这里使用三种方式读取一张图片

import matplotlib.pyplot as plt
import skimage.io as io
import cv2
from PIL import Image
import numpy as np
import torch

# width = 1081,  height=1920, channel=3

# 使用skimage读取图像
img_skimage = io.imread('./image/BackGround1.jpg')        # skimage.io imread()-----np.ndarray,  (H x W x C), [0, 255],RGB
print(img_skimage.shape)

# 使用opencv读取图像
img_cv = cv2.imread('./image/BackGround1.jpg')            # cv2.imread()------np.array, (H x W xC), [0, 255], BGR
print(img_cv.shape)

# 使用PIL读取
img_pil = Image.open('./image/BackGround1.jpg')         # PIL.Image.Image对象
img_pil_1 = np.array(img_pil)           # (H x W x C), [0, 255], RGB
print(img_pil_1.shape)

plt.figure()
for i, im in enumerate([img_skimage, img_cv, img_pil_1]):
    ax = plt.subplot(1, 3, i + 1)
    ax.imshow(im)

plt.show()

'''
三种方式输出的shape都是(1081, 1920, 3)
'''

将图片转化为Torch.Tensor

使用np.transpose进行转化,同时要注意numpy和torch中图片维度顺序的不同,因此需要进行转化

  • numpy image: H x W x C
  • torch image: C x H x W
tensor_skimage = torch.from_numpy(np.transpose(img_skimage, (2, 0, 1)))
print(tensor_skimage.shape)
tensor_cv = torch.from_numpy(np.transpose(img_cv, (2, 0, 1)))
print(tensor_cv.shape)
tensor_pil = torch.from_numpy(np.transpose(img_pil_1, (2, 0, 1)))
print(tensor_pil.shape)

'''
输出结果均为torch.Size([3, 1081, 1920])
'''

使用ImageFolder类

ImageFolder是torchvision提供好的一个类,可以让我们直接直接对某一个目录下文件夹内的图片加载为数据集,会自动检测jpg,jpeg,png等图片格式

下面为ImageFolder初始化的源码

super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                                  transform=transform,
                                  target_transform=target_transform,
                                  is_valid_file=is_valid_file)
  • root:加载的路径,注意如果root=’./’,实际上会检测’./‘的下级目录,比如’./image/'下的图片,不会再 './'直接检测

  • transform:可以添加transforms.Compose()进行图片的预处理,例如

    transforms.Compose(
        [
            transforms.ToTensor()
        ]
    )
    # 可以将图片转化为张量
    

下面展示一个使用ImageFolder的例子

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V5R05mLq-1633352953512)(C:\Users\Jajison\AppData\Roaming\Typora\typora-user-images\image-20211003220032923.png)]

这是’./image/'下的三张图片

from torchvision import transforms
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

transform = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)

train_set = tv.datasets.ImageFolder(root='./', transform=transform)
data_loader = DataLoader(dataset=train_set)

# transforms提供的类,注意不是方法,需要先实例化,可以torch.tensor转化为PIL.Image.Image对象
to_pil_image = transforms.ToPILImage()
# 因为ToPILImage()类中定义了__call__方法,因此可以使用to_pil_image(xxx)的方式来调用__call__方法,详情可以见源码

for image, label in data_loader:
    # [Batch, Channels, Height, Width]所以第一维度会是1
    print(type(image)) # torch.Size([1, 3, 1081, 1920]) #
    # 下面使用两张展示的方法
    # 方法1:Image.show()
    # 第一种方法会自动打开电脑默认的图片软件来展示图片
    # transforms.ToPILImage()中有一句
    # npimg = np.transpose(pic.numpy(), (1, 2, 0))
    # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
    # image[0]即为torch.Size([3, 1081, 1920])
    img = to_pil_image(image[0])
    img.show()

    # 方法2:plt.imshow(ndarray)
    img = image[0]  # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
    img = img.numpy()  # FloatTensor转为ndarray
    img = np.transpose(img, (1, 2, 0))  # 把channel那一维放到最后

    # 显示图片
    plt.imshow(img)
    plt.show()

重写DataSet类

DataSet是torch中的一个抽象类,用于进行重写自己的类

我们可以重写以下方法

def __getitem__(self, index) -> T_co:
    raise NotImplementedError

def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
    return ConcatDataset([self, other])

def __len__(self)
	return

例如

import os
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class MyDataset(Dataset):
    def __init__(self, file_path, transform = None):
        super(MyDataset, self).__init__()
        self.file_path = file_path
        self.transform = transform  # 对输入图像进行预处理,这里并没有做,预设为None
        self.image_names = os.listdir(self.file_path)  # 文件名的列表
        print(self.image_names)

    def __getitem__(self, idx):
        image = self.image_names[idx]
        image = io.imread(os.path.join(self.file_path, image))
        if self.transform:
            image= self.transform(image)

        return image

    def __len__(self):
        return len(self.image_names)

在自定义的MyDataSet中,会将file_path路径下的所有图片加入数据集,使用的是skimage.io.read将其写成np.array,可以使用transform将其装变成torch.tensor

将DataSet加载进DataLoader

# 设置自己存放的数据集位置,并尝试转化为PIL.Image.Image对象进行展示
transform = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)
imageloader = MyDataset(file_path="./leaves/",transform = None)

the_dataloader = DataLoader(dataset=imageloader,batch_size=2,shuffle=True)

to_pil_image = transforms.ToPILImage()

for i_batch,batch_data in enumerate(the_dataloader):
    print(i_batch)
    print(len(batch_data)) # 2,即上面设计的batch_size
    for X in batch_data:
        to_pil_image(X.numpy()).show()

在DataLoader中设定dataset,batch_size,和是否打乱shuffle

然后可以通过enumerate来遍历,注意通过enumerate,每一个图片的大小需要一直

其中i_batch表示label,batch_data表示一个batch的数据集

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值