在该系列的上一篇,我们讲解了计算图和自动求导的知识点,这个内容是Pytorch的基础也是重点,如果不记得了,回去看看吧~
我们本篇聊聊Pytorch中的Dataset类。
在进行深度学习的时候,最重要的是什么?没错,就是数据!数据的形式多种多样,可以是文本,可以是表格数据,可以是声音,可以是图像,甚至视频。当我们手上有了数据,接下来的步骤就是将数据读取处理给模型使用,Pytorch提供了很多工具,能让我们读取数据和预处理数据变得easy!
Pytorch的Dataset类是一个抽象类,源码如下,其内部有三个魔法方法:
class Dataset(object):"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""def __getitem__(self, index): raise NotImplementedErrordef __len__(self): raise NotImplementedErrordef __add__(self, other): return ConcatDataset([self, other])
当我们加载数据时,可以定义子类继承Dataset类,定义的子类需要重载两个方法,分别是:
__len__方法,用来提供数据库的大小。__getitem__方法,支持一个整形索引,重来获取单个数据,范围是__len__定义的,范围是[0, len(self)]
例如我们可以定义自己的数据类,继承和重写这个抽象类,例如:
import torchimport pandas as pdfrom torch.utils.data import Dataset
我们已经导入必要的模块,然后可以按你的需要继承重写:
class myDataset(Dataset):def __init__(self, csv_file, root_dir): self.csv_data = csv_file self.root_dir = root_dirdef __len__(self): return len(self.csv_data)def __getitem__(self, idx): data = (self.csv_data[idx]) return datadef read(self, csv, index): return pd.read_csv(csv[index])
在上面的代码中,我们在初始化中写了个文件的名称和路径,调用__len__方法可以获取传入数据的个数,而__getitem__可以根据索引获取传入数据的名称。使用read方法可以打开传入的数据。
csv_file = [r'F:/train.csv',r'F:/test.csv']root_dir = r'F/'ds1 = myDataset(csv_file,root_dir)
ds1[0]
可以得到:
len(ds1)
得到 2,可以得出一共有两个文件
ds1.read(ds1,1)
可以读取csv的内容:
前文传送门:
从零开始深度学习Pytorch笔记(1)——安装Pytorch 从零开始深度学习Pytorch笔记(2)——张量的创建(上) 从零开始深度学习Pytorch笔记(3)——张量的创建(下) 从零开始深度学习Pytorch笔记(4)——张量的拼接与切分 从零开始深度学习Pytorch笔记(5)——张量的索引与变换 从零开始深度学习Pytorch笔记(6)——张量的数学运算 从零开始深度学习Pytorch笔记(7)—— 使用Pytorch实现线性回归 从零开始深度学习Pytorch笔记(8)—— 计算图与自动求导(上)从零开始深度学习Pytorch笔记(9)—— 计算图与自动求导(下)
关注我们
更多开发者活动及技术讯息
请关注微软中国MSDN公众号
▲▲