用开源代码跑自己的数据集:修改dataloader

论文需要跑网络对比实验。那么如何用 Github 上的代码(或者其他开源代码)跑我们需要它跑的数据集呢

答:修改开源代码对应的 dataloader 部分即可。

下文将简要介绍与 PyTorch 框架的 dataloader 的相关知识。
首先引用 PyTorch 中文教程中关于 Dataset 抽象类的介绍和 Dataloader 的介绍 :

  • 我们在做深度学习训练时,首先要做的是做一个数据集类,它可能需要完成自动打乱数据数据处理批量提供 batchsize 数据等功能。 PyTorch 在 torch.utils.data 中提供了 Dataset抽象类,用于构建一个数据集类,可以对数据批量处理,可以构建一个数据集索引,PyTorch中的以方便批量训练数据时,方便调取。
  • 数据集创建完成后,我们可以对数据进行索引,但是还是无法实现批量获取数据,这时,我们就用到 DataLoader 去加载数据做一个数据加载器。

另外,在 PyTorch 官方的 Tutorial 中,我觉得有一句话很棒

The DataLoader combines the dataset and a sampler, returning an iterable over the dataset.

它指出了 DataLoader 本质上是一个 迭代器,而且同时由 dataset 和 sampler 组成。一语道破,妙不可言。

上文中关于 “数据加载器” 的概念,同时出现 dataloaderDataloader。因为后者是 PyTorch 提供的。通常使用的时候,我们对 Dataloader 的参数赋值,然后将 Dataloader 赋值给一个自己命名的 dataloader。如下所示:

train_loader = DataLoader(dataset = my_dataset,
                          batch_size = 32,
                          shuffle = True,
                          num_workers = 2)

接下来用更多示例代码做更详细的解释:

下面的代码 ex1,我专门把 from torch.utils.data import Datasetfrom torch.utils.data import DataLoader 写出来了,

为什么?

因为在写自己的类 MyDataset 的时候,类 MyDataset 要继承 PyTorch 的抽象类 Dataset。

另外,也用到了 PyTorch 的 DataLoader 来得到参数 batch_size 等赋值后的我们自己的 train_loader 。

### ex1
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class MyDataset(Dataset):

    # 初始化数据,比如从本地磁盘读入数据
    # 也可能要对数据进行标准化,裁剪等操作
    def __init__(self):
    
    
    # 返回数据集中的数据
    # 根据索引访问
    def __getitem__(self,index)
        return
    
    # 返回数据集的长度
    # 比如图像数据集中图像的数量
    def __len__(self):
        return
        
my_dataset = MyDataset()
train_loader = DataLoader(dataset = my_dataset,
                          batch_size = 32,
                          shuffle = True,
                          num_workers = 2)

下面用示例代码 ex2 来增加对类 MyDataset 的感觉:

ex2 这个代码的背景是要解决 分类 问题,代码数据的来源是 data.csv。当然在 init 函数中,还可以有其他一些代码,根据实际需求。比如 假设场景是 图像识别,那么在 init 函数中可能会有例如 ex3 的一段代码:

### ex2
class MyDataset(Dataset):

    # 初始化数据,比如从本地磁盘读入数据
    # 也可能要对数据进行标准化,裁剪等操作
    def __init__(self):
    xy = np.loadtxt('data.csv',delimiter=',',dtype=np.float32)
    self.len = xy.shape[0]
    self.data_input= torch.from_numpy(xy[:, 0:-1])
    self.label= torch.from_numpy(xy[:,[-1]])
    
    # 返回数据集中的数据
    # 根据索引访问
    def __getitem__(self,index)
        return self.data_input[index], self.label[index]
    
    # 返回数据集的长度
    # 比如图像数据集中图像的数量
    def __len__(self):
        return self.len
### ex3
from torchvision import transforms as T
class MyDataset(Dataset):

    # 初始化数据,比如从本地磁盘读入数据
    # 也可能要对数据进行标准化,裁剪等操作
    def __init__(self):
    	上文代码省略
    	transform = T.Compose([
    		T.Resize(112,112),
    		T.ToTensor(),    
   		 	T.Normalize(mean=[0.5], std=[0.5])
		])    
		
    def __getitem__(self,index)
        return 
    
    def __len__(self):
        return 

最后,由于在主函数中训练时,需要从 train_loader 遍历器中不停的取数据,再进行其他操作。如 ex4 所示的代码:

### ex4
for step, data in enumerate(train_loader):
    data_input, label = data

这样实现了一次遍历,若

epoach 不等于 1 的话,在外层加一个epoch循环。如 ex5所示的代码。如果对 epoach,batch size 和 iteration 的概念不清楚,可以移步至 此处

### ex5
for epoch in range(max_epoch):
    # 每个epoch
    model.train()
    for step, data in enumerate(train_loader):
        data_input, label = data

最后,我们在别人的开源代码中找到 两样东西。分别对应本文的示例代码中给 my_dataset 和 train_loader 赋值的来源,把它们的来源修改为我们自己数据集对应的。

参考

  • 本文得到了该视频的启发,该视频作者信息如下:
    PyTorch Zero To All Lecture by Sung Kim hunkim+ml@gmail.com at HKUST
    Code: https://github.com/hunkim/PyTorchZero…
    Slides: http://bit.ly/PyTorchZeroAll

  • PyTorch 中文教程:构建自己的数据集

  • 12
    点赞
  • 102
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
### 回答1: 在 PyTorch 中读取自定义数据集的一般步骤如下: 1. 定义数据集类:首先需要定义一个数据集类,继承自 `torch.utils.data.Dataset` 类,并实现 `__getitem__` 和 `__len__` 方法。在 `__getitem__` 方法中,根据索引返回一个样本的数据和标签。 2. 加载数据集:使用 `torch.utils.data.DataLoader` 类加载数据集,可以设置批量大小、多线程读取数据等参数。 下面是一个简单的示例代码,演示如何使用 PyTorch 读取自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y def __len__(self): return len(self.data) # 加载训练集和测试集 train_data = ... train_targets = ... train_dataset = CustomDataset(train_data, train_targets) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_data = ... test_targets = ... test_dataset = CustomDataset(test_data, test_targets) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 训练模型 for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(train_loader): # 前向传播、反向传播,更新参数 ... ``` 在上面的示例代码中,我们定义了一个 `CustomDataset` 类,加载了训练集和测试集,并使用 `DataLoader` 类分别对它们进行批量读取。在训练模型时,我们可以像使用 PyTorch 自带的数据集一样,循环遍历每个批次的数据和标签,进行前向传播、反向传播等操作。 ### 回答2: PyTorch是一个源的深度学习框架,它提供了丰富的功能用于读取和处理自定义数据集。下面是一个简单的步骤来读取自定义数据集。 首先,我们需要定义一个自定义数据集类,该类应继承自`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。`__len__`方法应返回数据集的样本数量,`__getitem__`方法根据给定索引返回一个样本。 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] return torch.tensor(sample) ``` 接下来,我们可以创建一个数据集实例并传入自定义数据。假设我们有一个包含多个样本的列表 `data`。 ```python data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) ``` 然后,我们可以使用`torch.utils.data.DataLoader`类加载数据集,并指定批次大小、是否打乱数据等。 ```python batch_size = 2 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 现在,我们可以迭代数据加载器来获取批次的样本。 ```python for batch in dataloader: print(batch) ``` 上面的代码将打印出两个批次的样本。如果`shuffle`参数设置为`True`,则每个批次的样本将是随机的。 总而言之,PyTorch提供了简单而强大的工具来读取和处理自定义数据集,可以根据实际情况进行适当修改和扩展。 ### 回答3: PyTorch是一个流行的深度学习框架,可以用来训练神经网络模型。要使用PyTorch读取自定义数据集,可以按照以下几个步骤进行: 1. 准备数据集:将自定义数据集组织成合适的目录结构。通常情况下,可以将数据集分为训练集、验证集和测试集,每个集合分别放在不同的文件夹中。确保每个文件夹中的数据按照类别进行分类,以便后续的标签处理。 2. 创建数据加载器:在PyTorch中,数据加载器是一个有助于有效读取和处理数据的类。可以使用`torchvision.datasets.ImageFolder`类创建一个数据加载器对象,通过传入数据集的目录路径来实现。 3. 数据预处理:在将数据传入模型之前,可能需要对数据进行一些预处理操作,例如图像变换、标准化或归一化等。可以使用`torchvision.transforms`中的类来实现这些预处理操作,然后将它们传入数据加载器中。 4. 创建数据迭代器:数据迭代器是连接数据集和模型的重要接口,它提供了一个逐批次加载数据的功能。可以使用`torch.utils.data.DataLoader`类创建数据迭代器对象,并设置一些参数,例如批量大小、是否打乱数据等。 5. 使用数据迭代器:在训练时,可以使用Python的迭代器来遍历数据集并加载数据。通常,它会在每个迭代步骤中返回一个批次的数据和标签。可以通过`for`循环来遍历数据迭代器,并在每个步骤中处理批次数据和标签。 这样,我们就可以在PyTorch中成功读取并处理自定义数据集。通过这种方式,我们可以更好地利用PyTorch的功能来训练和评估自己的深度学习模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

培之

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值