import torch
from torch.utils import data
import numpy as np
# 继承Dataset,自定义数据集及标签
class TestData(data.Dataset):
def __init__(self):
self.Data = np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])
self.Label = np.asarray([0,1,0,1,2])
def __getitem__(self,index):
txt = torch.from_numpy(self.Data[index])
label=torch.tensor(self.Label[index])
return txt,label
def __len__(self):
return len(self.Data)
#实例化类Test
Test = TestData()
# 打印索引为4的数据集及标签
print(Test[4])
# 打印数据集长度
print(Test.__len__())
# 批量获取数据集,对应参数为batch_size
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False,num_workers=0)
for i,traindata in enumerate(test_loader):
print('i:',i)
Data,Label = traindata
print('data:',Data)
print('Label:',Label)
在上述程序中需要注意DataLoader迭代器num_worker参数,该参数为进程数,在作者电脑上参数仅能为0的情况下才能运行,进程通常理解一个程序,打开爱奇艺和优酷应用即可理解为两个进程,不清楚为什么设置为2不能运行,也许依附GPU。后续查阅资料后进行补充说明。