首先import torch.utils.data
训练模型前的第一件事情就是导入训练时所需要用到的数据。
首先定义数据类
形式为下图代码所示,大概是这个形式。(根据自身需要用到的数据集进行调整)
import torch.utils.data
class MyDataset(torch.utils.data.Dataset):
def __init__(self,train_data,test_data):
super(MyDataset,self).__init__()
self.train_data=train_data
self.test_data=test_data
def __getitem(self,item):
return self.train_data(item)
def __len__(self):
return self.train_data.shape()
我们自定义的类要继承torch.utils.data中Dataset这个父类,在init初始化方法中采用super()这个特殊函数,super函数里必须要包含两个参数,分别是子类名和参数self,这样你的自定义数据类就可继承Dataset父类的方法。对于Dataset类,查看类的声明可以得知,必须重写 getitem 和 len 这两个函数。
定义完MyDataset后使用torch.utils.data.DataLoader。DataLoader是pytorch中读取数据的一个重要接口,基本上用pytorch训练都会用到。这个接口的目的是将自定义的Dataset根据batch size大小,是否shuffle等选项封装成一个batch size大小的tensor。
处理过程如下:
dataset=MyDataset()
loader=torch.utils.data.DataLoader(dataset,batch_size=2,shuffle=True,...)
以下是DataLoader初始化的参数:(按住ctrl 点击类名进行查询,自行根据所需调整相应的参数即可)
关于DataLoader更深一步的了解可跳转这篇文章:http://t.csdn.cn/MiRdH