1. 读取图像
下面三种方法读取图像得到的是相同的结果,且都是 numpy 类型:
from skimage.io import imread
from PIL import Image
import os, cv2
import numpy as np
def ski_reader(image_path):
image = imread(image_path)
return image
def pil_reader(image_path):
image = Image.open(image_path)
# 如果是png图像,则去掉透明度通道。
if os.path.splitext(image_path)[1] == '.png':
image = image.convert('RGB')
return np.array(image)
def cv2_reader(image_path):
# 使用默认参数读取会忽略透明度通道。返回值类型是numpy.ndarray。
image = cv2.imread(image_path)
# 把默认BGR转为RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image
上面读取的 image 的 shape = (h, w, 3),所以三个通道分别是 image[:, :, 0]、image[:, :, 1]、image[:, :, 2]。如果想让 image[0] 为 R 通道、image[1] 为 G 通道、image[2] 为 B 通道,需要使用 image = image.transpose(2, 0, 1) 转换成 shape = (c, h, w)。
使用 OpenCV 读取图像时 flags 参数的作用:
- flags > 0:返回三通道图像。即使原图是单通道灰度图,也返回三个值相等的通道。
- flags = 0:返回单通道图像。
- flags < 0:原样读取图像。如果要读取四维的 RGBA 图像,必须让 flags < 0,否则返回 RGB 图像。如果要读取 uint16 类型的图像,也只能让 flags < 0,否则返回 uint8 类型。
上面三个图像处理库能处理的最大数据类型都是 uint16,且处理 uint16 时只能读取和保存为 png 类型。
2. 读取数据
from torch.utils.data import Dataset, DataLoader
class MyDataSet(Dataset):
def __init__(self, images, labels, reader):
self.images = images
self.labels = labels
self.reader = reader
def __getitem__(self, item):
image = self.reader(self.images[item])
label = self.reader(self.labels[item])
return image, label
def __len__(self):
return len(self.labels)
if __name__ == '__main__':
train_data = MyDataSet(['1.png'], ['2.png'], cv2_reader)
data_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False)
for x, y in data_loader:
print(x.shape) # [b, h, w, c]
print(type(x)) # torch.Tensor
第 3-15 行定义读取数据的类,函数 __getitem__ 的返回值的作为 batch 的元素。第 18、19 行创建数据加载器,第 20 行使用创建的数据加载器读取数据。上面的 MyDataSet 继承 torch.utils.data import Dataset。其实也可以不继承,使用 yield 返回数据。因为 pytorch 的数据维度是 [b, c, h, w],而读取的图像维度一般是 [h, w, c],所以需要一步显式转换。