shuffle=True 用于打乱数据集:每次完整遍历 DataLoader(即每个 epoch)时,数据都会被重新打乱,以不同的随机顺序返回。
from torch.utils.data import Dataset, DataLoader
class DataSet(Dataset):
    """Toy map-style dataset holding the integers ``0 .. n-1``.

    Each sample's value equals its index, which makes the order in which a
    ``DataLoader`` visits the samples directly visible in the printed batches.
    """

    def __init__(self, n):
        """Create the dataset.

        Args:
            n: Number of samples; sample ``i`` has the value ``i``.
        """
        self.n = n
        self.data = list(range(n))

    def __len__(self):
        """Return the number of samples (the ``n`` given at construction)."""
        return self.n

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]
# Wrap the 32-sample toy dataset in a loader that yields batches of 8.
# With shuffle=True the DataLoader reshuffles the data on every pass, so each
# of the passes below prints the samples in a different order.
data_set = DataSet(32)
data_loader = DataLoader(data_set, batch_size=8, shuffle=True)

# Iterate over the whole dataset three times, printing one line per pass.
for epoch in range(3):
    if epoch:
        # Line break between passes only (none after the last one).
        print()
    for num in data_loader:
        print(num, end='\t')
实验结果(shuffle=True,三次遍历的顺序各不相同;由于是随机打乱,每次运行得到的具体顺序也会不同):
tensor([ 7, 23, 19, 4, 22, 26, 21, 31]) tensor([27, 14, 13, 12, 2, 24, 5, 10]) tensor([ 0, 15, 16, 30, 25, 8, 6, 29]) tensor([11, 3, 1, 18, 9, 20, 17, 28])
tensor([10, 11, 22, 28, 24, 19, 31, 8]) tensor([15, 2, 9, 1, 20, 14, 23, 16]) tensor([ 5, 25, 4, 6, 21, 30, 18, 27]) tensor([ 3, 13, 17, 12, 29, 0, 7, 26])
tensor([ 1, 2, 13, 20, 11, 19, 9, 22]) tensor([ 0, 14, 25, 24, 27, 31, 12, 28]) tensor([10, 4, 23, 15, 21, 30, 6, 16]) tensor([ 7, 17, 8, 26, 3, 29, 18, 5])
如果是 shuffle=False 的话,实验结果如下(三次遍历都按原始顺序返回,结果完全相同):
tensor([0, 1, 2, 3, 4, 5, 6, 7]) tensor([ 8, 9, 10, 11, 12, 13, 14, 15]) tensor([16, 17, 18, 19, 20, 21, 22, 23]) tensor([24, 25, 26, 27, 28, 29, 30, 31])
tensor([0, 1, 2, 3, 4, 5, 6, 7]) tensor([ 8, 9, 10, 11, 12, 13, 14, 15]) tensor([16, 17, 18, 19, 20, 21, 22, 23]) tensor([24, 25, 26, 27, 28, 29, 30, 31])
tensor([0, 1, 2, 3, 4, 5, 6, 7]) tensor([ 8, 9, 10, 11, 12, 13, 14, 15]) tensor([16, 17, 18, 19, 20, 21, 22, 23]) tensor([24, 25, 26, 27, 28, 29, 30, 31])