Pytorch的数据读取主要包含三个类:
- Dataset
- DataLoader
- DataLoaderIter
这三者大致是一个依次封装的关系: 1.被装进2., 2.被装进3.
一. torch.utils.data.Dataset
是一个抽象类, 自定义的Dataset需要继承它并且实现两个成员方法:
__getitem__()
__len__()
第一个最为重要, 即每次怎么读数据. 以图片为例:
-
def __getitem__(self, index):
-
img_path, label =
self.data[index].img_path,
self.data[index].label
-
img = Image.open(img_path)
-
-
return img, label
值得一提的是, pytorch还提供了很多常用的transform, 在torchvision.transforms
里面, 本文中不多介绍, 常用的有Resize
, RandomCrop
, Normalize
, ToTensor
(这个极为重要, 可以把一个PIL或numpy图片转为torch.Tensor
, 但是好像对numpy数组的转换比较受限, 所以这里建议在__getitem__()
里面用PIL来读图片, 而不是用skimage.io).
第二个比较简单, 就是返回整个数据集的长度:
-
def __len__(self):
-
return len(
self.data)
二. torch.utils.data.DataLoader
类定义为:
-
class torch.utils.data.DataLoader(
-
dataset,
-
batch_size=
1,
-
shuffle=
False,
-
sampler=None,
-
batch_sampler=None,
-
num_workers=
0,
-
collate_fn=<
function default_collate>,
-
pin_memory=
False,
-
drop_last=
False
-
)
可以看到, 主要参数有这么几个:
dataset
: 即上面自定义的dataset.collate_fn
: 这个函数用来打包batchnum_worker
: 非常简单的多线程方法, 只要设置为>=1, 就可以多线程预读数据
这个类其实就是下面将要讲的DataLoaderIter
的一个框架, 一共干了两件事:
- 定义了一堆成员变量, 到时候赋给
DataLoaderIter
, - 然后有一个
__iter__()
函数, 把自己 "装进"DataLoaderIter
里面.
-
def __iter__(self):
-
return DataLoaderIter(
self)
三. torch.utils.data.dataloader.DataLoaderIter
上面提到, DataLoader
就是DataLoaderIter
的一个框架, 用来传给DataLoaderIter
一堆参数, 并把自己装进DataLoaderIter
里。其实到这里就可以满足大多数训练的需求了, 比如
-
class CustomDataset(Dataset):
-
# 自定义自己的dataset
-
-
dataset = CustomDataset()
-
dataloader = Dataloader(dataset, ...)
-
-
for data
in dataloader:
-
# training...
在for 循环里, 总共有三点操作:
- 调用了
dataloader
的__iter__()
方法, 产生了一个DataLoaderIter
- 反复调用
DataLoaderIter
的__next__()
来得到batch, 具体操作就是, 多次调用dataset的__getitem__()
方法 (如果num_worker
>0就多线程调用), 然后用collate_fn
来把它们打包成batch. 中间还会涉及到shuffle
, 以及sample
的方法等. - 当数据读完后,
__next__()
抛出一个StopIteration
异常,for
循环结束,dataloader
失效.
四. 又一层封装
其实上面三个类已经可以搞定了, 仅供参考
-
class DataProvider:
-
def __init__(self, batch_size, is_cuda):
-
self.batch_size = batch_size
-
self.dataset = Dataset_triple(
self.batch_size,
-
transform
_=transforms.Compose(
-
[transforms.Scale([
224,
224]),
-
transforms.ToTensor(),
-
transforms.Normalize(mean=[
0.
485,
0.
456,
0.
406],
-
std=[
0.
229,
0.
224,
0.
225])]),
-
)
-
self.is_cuda = is_cuda
# 是否将batch放到gpu上
-
self.dataiter = None
-
self.iteration =
0
# 当前epoch的batch数
-
self.epoch =
0
# 统计训练了多少个epoch
-
-
def build(self):
-
dataloader = DataLoader(
self.dataset, batch_size=
self.batch_size, shuffle=True, num_workers=
0, drop_last=True)
-
self.dataiter = DataLoaderIter(dataloader)
-
-
def next(self):
-
if
self.dataiter is
None:
-
self.build()
-
try:
-
batch =
self.dataiter.
next()
-
self.iteration +=
1
-
-
if
self.
is_cuda:
-
batch = [batch[
0].cuda(), batch[
1].cuda(), batch[
2].cuda()]
-
return batch
-
-
except
StopIteration:
# 一个epoch结束后reload
-
self.epoch +=
1
-
self.build()
-
self.iteration =
1
# reset and return the 1st batch
-
-
batch =
self.dataiter.
next()
-
if
self.
is_cuda:
-
batch = [batch[
0].cuda(), batch[
1].cuda(), batch[
2].cuda()]
-
return batch
感谢以下链接提供的参考:
转载:https://blog.csdn.net/weixin_39739616/article/details/83824944