Pytorch的Dataset与DataLoader用于自定义数据集和梯度训练。
首先,定义自己的数据集需要继承Dataset
from torch.utils.data import Dataset,DataLoader
class GetData(Dataset):
#初始化需要定义数据集
def __init__(self,datas,labels):
self.data = datas
self.label = labels
#迭代获取数据
def __getitem__(self,index):
data = self.data[index]
label = self.label[index]
return data,label
def __len__(self):
return len(self.data)
#随机生成测试集
import numpy as np
source_data = np.random.rand(10,20)
source_label = np.random.rand(0,2,(10,1))
myData = GetData(source_data,source_label)
其次,构建DataLoader对象
myDataLoader = DataLoader(myData,batch_size=5,shuffle=True,drop_last=False,num_workers=2)
for i,data in enumerate(myDataLoader):
print('第{}个batch \n'.format(i,data))
#迭代
EPOCHS=5
for epoch in range(EPOCHS):
for i , data in enumerate(myDataLoader):
print('第{}个batch \n'.format(i,data))