PyTorch(一) DataSet and DataLoader
目录
一、DataSet
1.1、DataSet 简介
DataSet 和 DataLoader两个部分是数据封装成以 Batch 形式送入模型的关键。这篇文章将这部分分成两步进行介绍,第一块先介绍一下DataSet类。
导入 DataSet 类的方式:
from torch.utils.data import Dataset
DataSet 是一个抽象的数据封装类,第一步需要自定义一个继承 DataSet 的数据类用于自定义数据形式,对于 DataSet 中最重要的两个需要重写的函数就是 __getitem__() 以及 __len__()。
__len__(): 这个函数是用于获取当前数据集的大小。
__getitem__(): 这个函数极为重要,这个函数返回的每一条数据形式就是你的 DataLoader 接收到的数据形式,和之后DataLoader 中的 collate_fn 的定义息息相关。
1.2、DataSet 代码实现
class MyData(Dataset):
def __init__(self, data, target):
self.data = data
self.target = target
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data = self.data[index]
target = self.target[index]
return [data, target]
首先这个继承类 MyData 中实现的就是__init__(),__len__(),__getitem__() 这个三个函数。
__init__(): 就是将传进来的数据以及标签进行赋值操作
__len__(): 这个函数就是返回整个数据集的大小操作
__getitem__(): 这个函数决定之后Dataloader遍历的时候每一次获取到的数据类型是什么样子的
二、DataLoader
2.1、DataLoader 简介
导入 DataLoader 的方式
from torch.utils.data import DataLoader
2.2、DataLoader 源码介绍
2.3、DataLoader 代码实现
三、DataSet 与 DataLoader 的相关性