论文需要跑网络对比实验。那么如何用 Github 上的代码(或者其他开源代码)跑我们需要它跑的数据集呢?
答:修改开源代码对应的 dataloader 部分即可。
下文将简要介绍与 PyTorch 框架的 dataloader 的相关知识。
首先引用 PyTorch 中文教程中关于 Dataset 抽象类的介绍和 Dataloader 的介绍 :
- 我们在做深度学习训练时,首先要做的是做一个数据集类,它可能需要完成自动打乱数据、数据处理、批量提供 batchsize 数据等功能。 PyTorch 在 torch.utils.data 中提供了 Dataset 的抽象类,用于构建一个数据集类,可以对数据批量处理,可以构建一个数据集索引,PyTorch中的以方便批量训练数据时,方便调取。
- 数据集创建完成后,我们可以对数据进行索引,但是还是无法实现批量获取数据,这时,我们就用到 DataLoader 去加载数据做一个数据加载器。
另外,在 PyTorch 官方的 Tutorial 中,我觉得有一句话很棒:
The DataLoader combines the dataset and a sampler, returning an iterable over the dataset.
它指出了 DataLoader 本质上是一个 迭代器,而且同时由 dataset 和 sampler 组成。一语道破,妙不可言。
上文中关于 “数据加载器” 的概念,同时出现 dataloader 和 Dataloader。因为后者是 PyTorch 提供的。通常使用的时候,我们对 Dataloader 的参数赋值,然后将 Dataloader 赋值给一个自己命名的 dataloader。如下所示:
train_loader = DataLoader(dataset = my_dataset,
batch_size = 32,
shuffle = True,
num_workers = 2)
接下来用更多示例代码做更详细的解释:
下面的代码 ex1,我专门把 from torch.utils.data import Dataset
与from torch.utils.data import DataLoader
写出来了,
为什么?
因为在写自己的类 MyDataset 的时候,类 MyDataset 要继承 PyTorch 的抽象类 Dataset。
另外,也用到了 PyTorch 的 DataLoader 来得到参数 batch_size 等赋值后的我们自己的 train_loader 。
### ex1
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class MyDataset(Dataset):
# 初始化数据,比如从本地磁盘读入数据
# 也可能要对数据进行标准化,裁剪等操作
def __init__(self):
# 返回数据集中的数据
# 根据索引访问
def __getitem__(self,index)
return
# 返回数据集的长度
# 比如图像数据集中图像的数量
def __len__(self):
return
my_dataset = MyDataset()
train_loader = DataLoader(dataset = my_dataset,
batch_size = 32,
shuffle = True,
num_workers = 2)
下面用示例代码 ex2 来增加对类 MyDataset 的感觉:
ex2 这个代码的背景是要解决 分类 问题,代码数据的来源是 data.csv。当然在 init 函数中,还可以有其他一些代码,根据实际需求。比如 假设场景是 图像识别,那么在 init 函数中可能会有例如 ex3 的一段代码:
### ex2
class MyDataset(Dataset):
# 初始化数据,比如从本地磁盘读入数据
# 也可能要对数据进行标准化,裁剪等操作
def __init__(self):
xy = np.loadtxt('data.csv',delimiter=',',dtype=np.float32)
self.len = xy.shape[0]
self.data_input= torch.from_numpy(xy[:, 0:-1])
self.label= torch.from_numpy(xy[:,[-1]])
# 返回数据集中的数据
# 根据索引访问
def __getitem__(self,index)
return self.data_input[index], self.label[index]
# 返回数据集的长度
# 比如图像数据集中图像的数量
def __len__(self):
return self.len
### ex3
from torchvision import transforms as T
class MyDataset(Dataset):
# 初始化数据,比如从本地磁盘读入数据
# 也可能要对数据进行标准化,裁剪等操作
def __init__(self):
上文代码省略
transform = T.Compose([
T.Resize(112,112),
T.ToTensor(),
T.Normalize(mean=[0.5], std=[0.5])
])
def __getitem__(self,index)
return
def __len__(self):
return
最后,由于在主函数中训练时,需要从 train_loader 遍历器中不停的取数据,再进行其他操作。如 ex4 所示的代码:
### ex4
for step, data in enumerate(train_loader):
data_input, label = data
这样实现了一次遍历,若
epoach 不等于 1 的话,在外层加一个epoch循环。如 ex5所示的代码。如果对 epoach,batch size 和 iteration 的概念不清楚,可以移步至 此处 。
### ex5
for epoch in range(max_epoch):
# 每个epoch
model.train()
for step, data in enumerate(train_loader):
data_input, label = data
最后,我们在别人的开源代码中找到 两样东西。分别对应本文的示例代码中给 my_dataset 和 train_loader 赋值的来源,把它们的来源修改为我们自己数据集对应的。
参考
-
本文得到了该视频的启发,该视频作者信息如下:
PyTorch Zero To All Lecture by Sung Kim hunkim+ml@gmail.com at HKUST
Code: https://github.com/hunkim/PyTorchZero…
Slides: http://bit.ly/PyTorchZeroAll