【Deep learning with pytorch 自学教程】数据集处理:Dataset&Dataloder 的使用
前言
在看过了许许多多的科研大佬的论文之后,发现pytorch 被越来越多人使用喜爱,于是也想尝试使用pytorch开始我的科研之路。这一系列博客主要参考pytorch官网自带的教程,并模仿教程中的代码实现建议对照着网站的教程看这篇博客。
数据集
数据集的读取与预处理对于深度学习是必不可少的部分。在tensorflow中数据的读取需要转化为tfrecord形式的文件进行读取,而在pytorch中将数据的读取打包为一个类(Dataset Class)在该类的基础上可以拓展子父类,实现对数据集的预处理和迭代读取。
打包数据集(Dataset Class)
torch.utils.data.Dataset是一个代表数据集的抽象类,想要为自己的神经网络打包一个数据集可以继承这个抽象类,并需要重写这个类中以下几个函数:
__len__函数:len(dataset)可以返回数据集的大小;
__getitem__函数: 实现通过下表读取数据集,如:dataset[i]返回数据集中第i个数据
打包数据集可以这样安排,在_ _init _ _函数中写入数据的读取路径与读取方式(如csv文件可以通过pd.read_csv(file)读取,npy文件可以通过np.load(file)读取,但在这里还没有真正的读取,只是定义读取的方式,真正的读取在getitem之中),在__getitem__函数中将具体的文件读出来。这样可以节省内存,不需要将所有的数据都放在内存里,需要哪一个读哪一个就好了。在getitem函数之中,返回的数据形式可以自定义,比如在给出的例子中,可以定义为{‘image’:image,‘landmarks’:landmark}这样的形式。
具体代码:
class FaceLandmarksDataset(Dataset):
"""Face Landmarks dataset.
这是一个标记人脸特殊点的数据集,数据的存储方式为.csv
数据的存储形式为:
图片名称 标记点1_x 标记点1_y 标记点2_x 标记点2_y ......
"""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.landmarks_frame = pd.read_csv(csv_file)#定义数据集的读取方式
self.root_dir = root_dir
self.transform = transform #这里是对数据集的预处理,后面会讲到
def __len__(self):
return len(self.landmarks_frame)#定义len函数,返回数据集的大小
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()#将索引转化为链表的形式
img_name = os.path.join(self.root_dir,
self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name) #读取图片
landmarks = self.landmarks_frame.iloc[idx, 1:] #读取所有的标记点
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1, 2)#将每一个点的横纵坐标列为一行
sample = {'image': image, 'landmarks': landmarks}#定义数据返回的形式
if self.transform:
sample = self.transform(sample)
return sample
具体的使用如下:
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
root_dir='data/faces/')#实例化之前定义的类
fig = plt.figure()
for i in range(len(face_dataset)):#运用len函数得到整个数据集的长度
sample = face_dataset[i]#通过下标的形式调用getitem函数
print(i, sample['image'].shape, sample['landmarks'].shape)#这里返回的是形式和之前定义的一样
ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)
if i == 3:
plt.show()
break
结果是:
0 (324, 215, 3) (68, 2)
1 (500, 333, 3) (68, 2)
2 (250, 258, 3) (68, 2)
3 (434, 290, 3) (68, 2)
预处理(Transforms)
数据的预处理在之前打包数据集的时候也有提到,数据的预处理可以包括如图片的尺度变化,数据增强,将图片转换为torch的形式。
为了不每次都将数据预处理所需要的参数都传入到函数中,一般将数据的预处理写作一个可以调用的类的形式。运用__call__函数和__init__函数实现,其中__init__函数可以定义预处理的所需要的参数,__call__函数可以用来传入数据,调用处理函数。
tsfm = Transform(params) #实例化预处理方式
transformed_sample = tsfm(sample) #传入数据,得到预处理之后的transformed_sample
具体实现的预处理函数如下:
class Rescale(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""
def __init__(self, output_size)