这些天看的东西,真的是比较多,相比以前来说,对我的学习方式起到颠覆性作用。我目前觉得,我们学到的东西,更多是孤立的,因此,在吸收一定知识后,需要在脑子里形成知识体系。需要把自己以前学到的东西进行整理,形成一个体系,这篇文章讲解的是,深度学习中pytorch数据集的构造!!!
pytorch中有两个自定义管理数据集的类,
torch.utils.data.DataSet
torvchvision.datasets.ImageFolder
这里主要讲解的第一种。
DataSet源码
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
我们设计自己数据集类的时候, 只需要重写 __getitem__、__len__
两个函数,分别的功能是, 通过切片返回具样例、返回样本个数。
以下是voc2012数据集分割的例子:
import os
import numpy as np
from PIL import Image
from torch.utils import data
def read_images(root, train):
txt_fname = os.path.join(root, 'ImageSets/Segmentation/') + ('train.txt' if train else 'val.txt')
with open(txt_fname, 'r') as f:
images = f.read().split()
data = [os.path.join(root, 'JPEGImages', i + '.jpg') for i in images]
label = [os.path.join(root, 'SegmentationClass', i + '.png') for i in images]
return data, label
class VocSegDataset(data.Dataset):
def __init__(self, cfg, train, transforms=None):
self.cfg = cfg
self.train = train
self.transforms = transforms
self.data_list, self.label_list = read_images(self.cfg.DATASETS.ROOT, train)
def __getitem__(self, item):
img = self.data_list[item]
label = self.label_list[item]
img = Image.open(img)
# load label
label = Image.open(label)
img, label = self.transforms(img, label)
return img, label
def __len__(self):
return len(self.data_list)
通过上面的操作,我们构建自己数据集类,接下来,构建一个 Dataloader
类,这个作用是训练过程中,返回 batch个样例。
Dataloder源码
由于源码过于臃肿了,这里知识摘出对应的构造函数:
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=default_collate,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None):
构造函数中,每个参数的意思就不一一介绍了,只着重的讲解下,可调用函数collate_fn
。我们首先看一个构建Dataloader
的实例:
def build_dataset(cfg, transforms, is_train=True):
datasets = VocSegDataset(cfg, is_train, transforms)
return datasets
def make_data_loader(cfg, is_train=True):
if is_train:
batch_size = cfg.SOLVER.IMS_PER_BATCH
shuffle = True
else:
batch_size = cfg.TEST.IMS_PER_BATCH
shuffle = False
transforms = build_transforms(cfg, is_train)
datasets = build_dataset(cfg, transforms, is_train)
num_workers = cfg.DATALOADER.NUM_WORKERS
data_loader = data.DataLoader(
datasets, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True
)
return data_loader
上面第一个函数 build_dataset
返回数据集实例,第二个函数返回Dataloader
,关于Dataloader,我们需要注意的是,有时我们需要根据Dataset中的__getitem__
修改collate_fn
。
我们来看下源码:
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
我们在源码中发现,collate_fn
的输入是一个list,里面的每个元素是__getitem__
的输出,由此,我们估计,default_collate
的作用是将这个list,**变换格式为[batch,C,H,W]**的tensor,我们在来看下源码:
if.......
.........
elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'): # namedtuple
return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
由于源码均是对类型的判断,因此,这里我们知识摘出,与voc2012
分割相关的部分,这个语句的意思是, 对[(img1, label1), (img2, label2)],首先返回[img1,img2],[lable1,label2],在继续返回两个tensor,一个是img,[batch,C,H,W],一个是label:[batch,C,H,W]。
所以,通过上面分析,如果,我们__getitem__
不符合collat_fn
不符合默认函数的判断时,需要修改该函数。
好了,先到这,接下来…慢慢聊程序,需要学的太多了