"""Learn how to build and use dataset classes (Dataset), transforms, and data loaders (DataLoader)."""
from __future__ import print_function, division # 执行精准除法
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = 'TRUE'
import numpy as np
import pandas as pd # 用于更容易地进行csv解析
from skimage import io, transform # 用于图像的IO和变换
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets
import warnings
warnings.filterwarnings("ignore")
def show_landmarks(image, landmarks):
    """Plot one image and overlay its landmark points as small red dots."""
    plt.imshow(image)
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    plt.pause(0.001)  # brief pause so the figure window gets a chance to redraw
def show_landmarks_batch(sample_batch):
    """Show a whole batch of images with their landmarks on a single grid."""
    images_batch = sample_batch['image']
    landmarks_batch = sample_batch['landmarks']
    n_images = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2  # make_grid pads each tile with a 2-pixel border by default
    grid = utils.make_grid(images_batch)
    # make_grid returns C x H x W; matplotlib wants H x W x C.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    for idx in range(n_images):
        # Shift each sample's landmarks right by its tile offset inside the grid.
        x_offset = idx * im_size + (idx + 1) * grid_border_size
        plt.scatter(landmarks_batch[idx, :, 0].numpy() + x_offset,
                    landmarks_batch[idx, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')
    plt.title('Batch from data_loader')
class FaceLandmarksDataset(Dataset):
    """Face landmarks dataset: pairs each image file with its (N, 2) landmark array."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with landmark annotations.
            root_dir (string): Directory containing all the images.
            transform (callable, optional): Optional transform applied to each sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # One csv row per sample.
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Return {'image': ndarray, 'landmarks': (N, 2) float ndarray} for sample idx."""
        # Accept tensor indices (e.g. handed out by a sampler) as well as plain ints;
        # pandas .iloc cannot index with a torch tensor.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # Columns 1.. hold the flattened x,y coordinates; reshape to (N, 2).
        landmarks = self.landmarks_frame.iloc[idx, 1:]
        landmarks = np.asarray(landmarks, dtype=float).reshape(-1, 2)
        _sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            _sample = self.transform(_sample)
        return _sample
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. A tuple is used
            as-is; an int makes the smaller image edge equal to output_size
            while keeping the aspect ratio.
            example:
            input=(640,600) output_size=480 output=(640*480/600,480)=(512,480)
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, _sample):
        image, landmarks = _sample['image'], _sample['landmarks']
        h, w = image.shape[:2]
        if isinstance(self.output_size, tuple):
            new_h, new_w = self.output_size
        else:
            # int: scale the shorter side to output_size, preserving aspect ratio.
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))
        # Landmarks are (x, y) = (col, row), so scale by (width ratio, height ratio).
        landmarks = landmarks * [new_w / w, new_h / h]
        return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
    """Crop the image in a sample at a random location.

    Args:
        output_size (tuple or int): Desired output size. An int produces a
            square crop.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, _sample):
        image, landmarks = _sample['image'], _sample['landmarks']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        # Upper bound is exclusive, so use +1: this lets the crop touch the
        # bottom/right edge and fixes the ValueError randint(0, 0) raised when
        # the image is exactly the crop size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        image = image[top: top + new_h, left: left + new_w]
        # Shift landmarks into the crop's coordinate frame (x -= left, y -= top).
        landmarks = landmarks - [left, top]
        return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
    """Convert the ndarrays in a sample to torch Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        # numpy images are H x W x C, but torch expects C x H x W,
        # so move the channel axis to the front before converting.
        chw_image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(chw_image),
                'landmarks': torch.from_numpy(landmarks)}
if __name__ == "__main__":
    print('########### 展示图片 ##########')
    # Read the annotation csv and pull one sample's name and landmarks.
    landmarks_frames = pd.read_csv('data/faces/face_landmarks.csv')
    index = 65
    src_name = landmarks_frames.iloc[index, 0]
    src_landmarks = landmarks_frames.iloc[index, 1:].values
    src_landmarks = src_landmarks.astype('float').reshape(-1, 2)
    print('Image name: {}'.format(src_name))
    print('Landmarks shape: {}'.format(src_landmarks.shape))
    print('First 4 Landmarks:\n{}'.format(src_landmarks[:4]))
    plt.figure()
    # Display one raw image with its landmarks overlaid.
    show_landmarks(io.imread(os.path.join('data/faces/', src_name)), src_landmarks)
    plt.show()
    print('########### 数据展示 ##########')
    face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/')
    plt.figure()
    for i in range(len(face_dataset)):
        data_sample = face_dataset[i]
        print(i, data_sample['image'].shape, data_sample['landmarks'].shape)
        ax = plt.subplot(1, 4, i + 1)
        plt.tight_layout()
        ax.set_title('Sample #{}'.format(i))
        ax.axis('off')
        show_landmarks(**data_sample)
        # show_landmarks(data_sample['image'], data_sample['landmarks'])
        if i == 3:  # only show the first four samples
            plt.show()
            break
    print('########### 数据转换 ##########')
    # Demonstrate each transform on its own, then composed.
    scale = Rescale(256)
    crop = RandomCrop(128)
    composed = transforms.Compose([Rescale(256), RandomCrop(224)])  # chained transforms
    fig = plt.figure()
    sample = face_dataset[65]
    for i, tsf in enumerate([scale, crop, composed]):
        transformed_sample = tsf(sample)
        ax = plt.subplot(1, 3, i + 1)
        plt.tight_layout()
        ax.set_title(type(tsf).__name__)
        show_landmarks(**transformed_sample)
    plt.show()
    print('########### 迭代数据集 ##########')
    transformed_dataset = FaceLandmarksDataset(
        csv_file='data/faces/face_landmarks.csv', root_dir='data/faces/',
        transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]))
    for i in range(len(transformed_dataset)):
        sample = transformed_dataset[i]
        print(i, sample['image'].size(), sample['landmarks'].size())
        if i == 3:
            break
    print('########### 批次迭代数据集 ##########')
    data_loader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)
    for i_batch, sample_batched in enumerate(data_loader):
        print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())
        # Observe the 4th batch, then stop.
        if i_batch == 3:
            plt.figure()
            show_landmarks_batch(sample_batched)
            plt.axis('off')
            plt.ioff()
            plt.show()
            break
    print('########### 利用torchvision创建数据加载器 ##########')
    data_transform = transforms.Compose([
        # RandomSizedCrop was deprecated and later removed from torchvision;
        # RandomResizedCrop is the drop-in replacement.
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    # ants_bees_dataset = datasets.ImageFolder(root='ants_bees_data/train', transform=data_transform)
    # dataset_loader = DataLoader(ants_bees_dataset, batch_size=4, shuffle=True, num_workers=4)