Pytorch
的DataLoader
不能够直接达到这个效果,必须要借助DataSet
来实现
DataSet
的用法可以参考:pytorch 构造读取数据的工具类 Dataset 与 DataLoader (pytorch Data学习一)
示例代码
from torch.utils.data import Dataset, DataLoader
import numpy as np
class MyDataSet(Dataset):
def __init__(self):
sample = 20000 # 数据量
self.data_1 = np.random.randn(sample) # 数据集1
self.data_2 = np.random.randn(sample) # 数据集2
self.data_3 = np.random.randn(sample) # 数据集3
self._len = sample # 必要,定义最大循环次数,一般也是全部的数据量
def __getitem__(self, item: int): # 这个item即为下标,整数
# 每次循环的时候返回的值
return self.data_1[item], self.data_2[item], self.data_3[item]
def __len__(self):
return self._len
if __name__ == '__main__':
data = MyDataSet()
dataloader = DataLoader(data, batch_size=3, shuffle=False, num_workers=0) # 这里的batch_size
n = 0
for data_1, data_2, data_3 in dataloader:
print("迭代{}次".format(n), data_1.numpy(), data_2.numpy(), data_3.numpy())
n += 1
在DataSet
的__getitem__
函数中,根据下标item取到数即可。在DataLoader
中,batch_size定为多少,每次取数时就会循环多少次__getitem__
,然后一并打包取出来。