本篇博客旨在实现pytorch读取图片并自定义图片数据集
图像加载方法
主流的图像加载方法主要有三种
下表中xxx表示图片的路径
库 | 函数/方法 | 返回值 | 图像像素格式 | 像素值范围 | 图像矩阵表示 |
---|---|---|---|---|---|
skimage | io.imread(xxx) | numpy.ndarray | RGB | [0, 255] | (H X W X C) |
cv2 | cv2.imread(xxx) | numpy.ndarray | BGR | [0, 255] | (H X W X C) |
Pillow(PIL) | Image.open(xxx) | PIL.Image.Image对象 | 根据图像格式,一般为RGB | [0, 255] | — |
这里使用三种方式读取一张图片
import matplotlib.pyplot as plt
import skimage.io as io
import cv2
from PIL import Image
import numpy as np
import torch
# width = 1081, height=1920, channel=3
# 使用skimage读取图像
img_skimage = io.imread('./image/BackGround1.jpg') # skimage.io imread()-----np.ndarray, (H x W x C), [0, 255],RGB
print(img_skimage.shape)
# 使用opencv读取图像
img_cv = cv2.imread('./image/BackGround1.jpg') # cv2.imread()------np.array, (H x W xC), [0, 255], BGR
print(img_cv.shape)
# 使用PIL读取
img_pil = Image.open('./image/BackGround1.jpg') # PIL.Image.Image对象
img_pil_1 = np.array(img_pil) # (H x W x C), [0, 255], RGB
print(img_pil_1.shape)
plt.figure()
for i, im in enumerate([img_skimage, img_cv, img_pil_1]):
ax = plt.subplot(1, 3, i + 1)
ax.imshow(im)
plt.show()
'''
三种方式输出的shape都是(1081, 1920, 3)
'''
将图片转化为Torch.Tensor
使用np.transpose进行转化,同时要注意numpy和torch中图片维度顺序的不同,因此需要进行转化
- numpy image: H x W x C
- torch image: C x H x W
tensor_skimage = torch.from_numpy(np.transpose(img_skimage, (2, 0, 1)))
print(tensor_skimage.shape)
tensor_cv = torch.from_numpy(np.transpose(img_cv, (2, 0, 1)))
print(tensor_cv.shape)
tensor_pil = torch.from_numpy(np.transpose(img_pil_1, (2, 0, 1)))
print(tensor_pil.shape)
'''
输出结果均为torch.Size([3, 1081, 1920])
'''
使用ImageFolder类
ImageFolder是torchvision提供好的一个类,可以让我们直接直接对某一个目录下文件夹内的图片加载为数据集,会自动检测jpg,jpeg,png等图片格式
下面为ImageFolder初始化的源码
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform,
target_transform=target_transform,
is_valid_file=is_valid_file)
-
root:加载的路径,注意如果root=’./’,实际上会检测’./‘的下级目录,比如’./image/'下的图片,不会再 './'直接检测
-
transform:可以添加transforms.Compose()进行图片的预处理,例如
transforms.Compose( [ transforms.ToTensor() ] ) # 可以将图片转化为张量
下面展示一个使用ImageFolder的例子
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-V5R05mLq-1633352953512)(C:\Users\Jajison\AppData\Roaming\Typora\typora-user-images\image-20211003220032923.png)]
这是’./image/'下的三张图片
from torchvision import transforms
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
transform = transforms.Compose(
[
transforms.ToTensor()
]
)
train_set = tv.datasets.ImageFolder(root='./', transform=transform)
data_loader = DataLoader(dataset=train_set)
# transforms提供的类,注意不是方法,需要先实例化,可以torch.tensor转化为PIL.Image.Image对象
to_pil_image = transforms.ToPILImage()
# 因为ToPILImage()类中定义了__call__方法,因此可以使用to_pil_image(xxx)的方式来调用__call__方法,详情可以见源码
for image, label in data_loader:
# [Batch, Channels, Height, Width]所以第一维度会是1
print(type(image)) # torch.Size([1, 3, 1081, 1920]) #
# 下面使用两张展示的方法
# 方法1:Image.show()
# 第一种方法会自动打开电脑默认的图片软件来展示图片
# transforms.ToPILImage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
# image[0]即为torch.Size([3, 1081, 1920])
img = to_pil_image(image[0])
img.show()
# 方法2:plt.imshow(ndarray)
img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
img = img.numpy() # FloatTensor转为ndarray
img = np.transpose(img, (1, 2, 0)) # 把channel那一维放到最后
# 显示图片
plt.imshow(img)
plt.show()
重写DataSet类
DataSet是torch中的一个抽象类,用于进行重写自己的类
我们可以重写以下方法
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
def __len__(self)
return
例如
import os
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
class MyDataset(Dataset):
def __init__(self, file_path, transform = None):
super(MyDataset, self).__init__()
self.file_path = file_path
self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
self.image_names = os.listdir(self.file_path) # 文件名的列表
print(self.image_names)
def __getitem__(self, idx):
image = self.image_names[idx]
image = io.imread(os.path.join(self.file_path, image))
if self.transform:
image= self.transform(image)
return image
def __len__(self):
return len(self.image_names)
在自定义的MyDataSet中,会将file_path路径下的所有图片加入数据集,使用的是skimage.io.read将其写成np.array,可以使用transform将其装变成torch.tensor
将DataSet加载进DataLoader
# 设置自己存放的数据集位置,并尝试转化为PIL.Image.Image对象进行展示
transform = transforms.Compose(
[
transforms.ToTensor()
]
)
imageloader = MyDataset(file_path="./leaves/",transform = None)
the_dataloader = DataLoader(dataset=imageloader,batch_size=2,shuffle=True)
to_pil_image = transforms.ToPILImage()
for i_batch,batch_data in enumerate(the_dataloader):
print(i_batch)
print(len(batch_data)) # 2,即上面设计的batch_size
for X in batch_data:
to_pil_image(X.numpy()).show()
在DataLoader中设定dataset,batch_size,和是否打乱shuffle
然后可以通过enumerate来遍历,注意通过enumerate,每一个图片的大小需要一直
其中i_batch表示label,batch_data表示一个batch的数据集