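Notes on how to load and preprocess data with PyTorch.
There are three main cases:
1 Datasets provided by torchvision
- This is the simplest case.
- For these datasets PyTorch has already done the work for us; for most of them (CIFAR10 and MNIST, for example) you do not even have to download the data yourself.
- ImageNet, CIFAR10, MNIST and others all come with ready-made loaders, so first check whether the dataset you want to use is one of them.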
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),  # convert to a float tensor scaled to [0, 1] (divides by 255)
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # map to [-1, 1]: channel = (channel - mean) / std
])

# download=True fetches the data into root if it is not already there
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=4, shuffle=True, num_workers=2)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=4, shuffle=False, num_workers=2)
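The two normalization comments can be checked by hand. The sketch below works through the arithmetic for a single 8-bit channel value in plain Python; `normalize_pixel` is a hypothetical helper written for this illustration, not a torchvision function.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Mimic ToTensor followed by Normalize for one 8-bit channel value."""
    scaled = value / 255.0        # ToTensor step: [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std  # Normalize step: (x - mean) / std


# The darkest and brightest pixels land on the ends of [-1, 1].
for v in (0, 128, 255):
    print(v, round(normalize_pixel(v), 4))
```

With mean and std both 0.5, the value 0 maps to -1 and 255 maps to 1, which is why this pair of transforms is often described as scaling to (-1, 1).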
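2 Datasets with a specific directory structure
- This covers datasets that are not among those PyTorch provides, but whose files are laid out as follows, with one sub-directory per class: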
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
- Such a dataset can be loaded with ImageFolder, the generic dataset class in torchvision.
- Usage:
import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # RandomSizedCrop is the old, deprecated name of this transform
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train', transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset, batch_size=4, shuffle=True, num_workers=4)
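ImageFolder takes its labels from the sub-directory names: the names are sorted and each one gets the index of its position, which the dataset exposes as class_to_idx. The sketch below reproduces that convention with the standard library only, on a throwaway directory. `class_to_idx_from_dirs` is a hypothetical helper for illustration, not torchvision's implementation.

```python
import os
import tempfile


def class_to_idx_from_dirs(root):
    """Map each class sub-directory of root to an index, in sorted name order."""
    classes = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    return {name: idx for idx, name in enumerate(classes)}


# Build a tiny root/ants, root/bees layout in a temporary directory.
with tempfile.TemporaryDirectory() as root:
    for cls in ("bees", "ants"):
        os.makedirs(os.path.join(root, cls))
    mapping = class_to_idx_from_dirs(root)

print(mapping)  # {'ants': 0, 'bees': 1}
```

Because the order is alphabetical rather than creation order, renaming a class directory can change the integer label of every class that sorts after it.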
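3 Any other dataset
- The last case is a dataset that is neither built in nor laid out the way ImageFolder expects, so you have to handle it yourself.
- First, define a dataset class (myDataset) that inherits from the abstract class Dataset and implements the two methods __len__ and __getitem__, usually together with an initializer __init__.
- Next, implement the preprocessing steps your images need and wrap each one in a class. Commonly used transforms can already be found in torchvision. Combine them with torchvision.transforms.Compose into a single transform, pass it as an argument to the myDataset class above, and instantiate myDataset to obtain the transformed_dataset object.
- Finally, pass transformed_dataset to torch.utils.data.DataLoader and set options such as shuffling and batch size as needed.
- For details see: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html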
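The three steps above can be sketched on synthetic data before turning to the full example. This is a minimal illustration under assumptions, not the tutorial's code: `MyDataset`, its random tensors, and the `scale_to_unit_range` transform are all made up for this sketch, and num_workers=0 keeps everything in one process.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    """Toy dataset: random 3x8x8 'images' with alternating 0/1 labels."""

    def __init__(self, num_samples=10, transform=None):
        self.images = torch.rand(num_samples, 3, 8, 8)
        self.labels = torch.arange(num_samples) % 2
        self.transform = transform

    def __len__(self):
        # Enables len(dataset)
        return len(self.images)

    def __getitem__(self, idx):
        # Enables dataset[idx]; the transform is applied on access
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


def scale_to_unit_range(image):
    """Hypothetical transform: map values from [0, 1] to [-1, 1]."""
    return image * 2.0 - 1.0


transformed_dataset = MyDataset(transform=scale_to_unit_range)
loader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=0)

batch_shapes = [tuple(images.shape) for images, labels in loader]
print(batch_shapes)  # [(4, 3, 8, 8), (4, 3, 8, 8), (2, 3, 8, 8)]
```

Ten samples with a batch size of 4 give two full batches and a final batch of 2, since DataLoader keeps the incomplete last batch unless drop_last=True is set.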
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warning
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
######################################################################
# Read the data
landmarks_frame = pd.read_csv('./data/faces/face_landmarks.csv')
# landmarks_frame.info()
n = 65
img_name = landmarks_frame.iloc[n, 0]  # row n, column 0: the image file name
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()  # row n, columns 1-136: the landmark x/y coordinates (as_matrix() was removed in pandas 1.0)
landmarks = landmarks.astype('float').reshape(-1, 2)  # reshape to (68, 2): column 0 holds x, column 1 holds y
print('Image name: {}'.format(img_name))  # file name of the n-th image
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))  # (x, y) of the first four points

def show_landmarks(image, landmarks):
    # show image with landmarks
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    # plt.pause(0.001)  # pause a bit so that plots are updated
    plt.pause(3)

plt.figure()
img = io.imread(os.path.join('./data/faces/', img_name))
show_landmarks(img, landmarks)
plt.show()
######################################################################
# Define the face landmarks dataset
#
class FaceLandmarksDataset(Dataset):  # inherits from Dataset
    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):  # returns the size of the dataset; usage: len(dataset)
        return len(self.landmarks_frame)

    def __getitem__(self, idx):  # supports integer indices from 0 to len(self) - 1; usage: dataset[i] returns the i-th sample
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()  # as_matrix() was removed in pandas 1.0
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}  # each sample is a dict
        if self.transform:
            # Transform objects define __call__, so they can be applied like a
            # function, transform(sample), to crop, rescale and so on
            sample = self.transform(sample)
        return sample

######################################################################
# Instantiate the face dataset and display a few samples
face_dataset = FaceLandmarksDataset(csv_file='./data/faces/face_landmarks.csv', root_dir='./data/faces/')
fig = plt.figure()
for i in range(len(face_dataset)):
    sample = face_dataset[i]  # __getitem__ makes the dataset indexable; returns the i-th image and landmarks as a dict
    print(i, sample['image'].shape, sample['landmarks'].shape)
    ax = plt.subplot(1, 4, i+1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)  # ** unpacks the dict into the keyword arguments image=... and landmarks=...
    if i == 3:  # show the figure after the first four samples, then stop
        plt.show()
        break

######################################################################
# Implementing three transform classes (preprocessing steps):
# The samples returned above are raw images of different sizes, so they are generally not fed directly into a convolutional network.
# Our dataset class accepts a transform argument; below we implement the transforms and pass them to the dataset.
class Rescale(object):
    """Rescale the image in a sample to a given size.
    Args:
        output_size (tuple or int): Desired output size. If tuple, output is matched to output_size.
            If int, smaller of image edges is matched to output_size keeping aspect ratio the same.
    """
    def __init__(self, output_size):  # output_size is the desired size of the output image
        assert isinstance(output_size, (int, tuple))  # fail early on an unsupported type
        if isinstance(output_size, tuple):
            assert len(output_size) == 2  # a tuple such as (211, 985) gives an output of exactly that size
        # Keep an int as an int: __call__ uses it as the length of the shorter edge.
        # Converting it to a (size, size) tuple here would silently drop the aspect ratio.
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]  # image.shape is (h, w, c)
        if isinstance(self.output_size, int):  # int: the shorter edge becomes output_size, the longer edge scales proportionally
            if h > w:  # w is the shorter edge, so scale h
                new_h, new_w = self.output_size * h/w, self.output_size
            else:  # h is the shorter edge (or h == w), so scale w
                new_h, new_w = self.output_size, self.output_size * w/h
        else:  # tuple: use it directly as the output size
            new_h, new_w = self.output_size
        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))
        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w/w, new_h/h]
        return {'image': img, 'landmarks': landmarks}  # __getitem__ returns a dict, so the transform must return one too

class RandomCrop(object):
    """Crop randomly the image in a sample.
    Args:
        output_size (tuple or int): Desired output size. If int, square crop is made.
    """
    def __init__(self, output_size):  # output_size is the desired size of the crop
        assert isinstance(output_size, (int, tuple))  # fail early on an unsupported type
        if isinstance(output_size, int):  # an int such as 256 gives a (256, 256) crop
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2  # a tuple such as (211, 985) gives a crop of exactly that size
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]  # image.shape is (h, w, c)
        new_h, new_w = self.output_size
        # randint excludes the upper bound; the +1 keeps it valid when the
        # image is exactly as large as the crop
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        image = image[top:top+new_h, left:left+new_w]
        landmarks = landmarks - [left, top]  # shift the landmarks into the crop's coordinate frame
        return {'image': image, 'landmarks': landmarks}

class ToTensor(object):  # the third class converts numpy arrays to tensors
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))  # reorder the axes into the torch layout
        return {'image': torch.from_numpy(image), 'landmarks': torch.from_numpy(landmarks)}

######################################################################
# Apply each of the above transforms on sample.
scale = Rescale(256)  # instantiate the first class; the object can now be called like a function
crop = RandomCrop(128)  # instantiate the second class; it is callable in the same way
composed = transforms.Compose([Rescale(256), RandomCrop(224)])
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):  # try each of the three transforms in turn
    transformed_sample = tsfrm(sample)  # takes the sample dict and returns a new dict with image and landmarks
    ax = plt.subplot(1, 3, i+1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)
plt.show()
######################################################################
# Iterating through the dataset
# With the transforms implemented above, we can now plug them into our custom dataset class.
# Each time the dataset is indexed, one image is read and then transformed:
transformed_dataset = FaceLandmarksDataset(
    csv_file='./data/faces/face_landmarks.csv',
    root_dir='./data/faces/',
    transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()])
)  # instantiate our custom dataset
for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]  # fetch one sample per iteration, at index i
    print(i, sample['image'].size(), sample['landmarks'].size())
    if i == 3:
        break
######################################################################
# The for loop above fetches only one image per iteration, which is inefficient.
# DataLoader adds what is missing: batching, shuffling, and loading in parallel with multiprocessing workers.
dataloader = DataLoader(dataset=transformed_dataset, batch_size=4, shuffle=True, num_workers=4)

# Helper function to show a batch
def show_landmarks_batch(sample_batched):  # sample_batched is one batch: a dict with 'image' and 'landmarks' tensors
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(3)  # batch shape is (B, C, H, W); index 3 is the width, which sets the x-offset between images
    grid_border_size = 2  # make_grid pads each image by 2 pixels by default
    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    for i in range(batch_size):
        plt.scatter(
            landmarks_batch[i, :, 0].numpy() + i*im_size + (i+1)*grid_border_size,
            landmarks_batch[i, :, 1].numpy() + grid_border_size,
            s=10, marker='.', c='r'
        )

######################################################################
#
for i_batch, sample_batched in enumerate(dataloader):  # i_batch can be thought of as the step number
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())
    if i_batch == 3:  # show the 4th batch, then stop
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
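
When landmarks are drawn over a grid built by utils.make_grid, each image's landmarks have to be shifted right by the position of that image in the grid. make_grid pads every image by 2 pixels by default, so in a single row image i starts at x = i * width + (i + 1) * padding. A minimal sketch of that arithmetic in plain Python, where `grid_x_offset` is a hypothetical helper for illustration and not part of torchvision:

```python
def grid_x_offset(index, image_width, padding=2):
    """Left edge, in pixels, of image `index` in a one-row padded grid."""
    return index * image_width + (index + 1) * padding


# Four 224-pixel-wide images side by side, as in a batch of size 4.
offsets = [grid_x_offset(i, image_width=224) for i in range(4)]
print(offsets)  # [2, 228, 454, 680]
```

The same padding applies vertically, so in a one-row grid the y coordinates shift down by the padding as well.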