DataLoader可以将datasets中的数据打包抽取,生成batch_size便于深度学习训练。
具体用法如下:
import torchvision.datasets
from torch.utils.data import DataLoader
# 由于测试数据集数量较少,这里创建CIFAR10中的测试数据
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
test_set = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# dataset= 为选择数据集, bach_size = 为一次抓取多少个数据, shuffle为每个epoch是否打乱顺序,num_worders= 为设置并行计算个数
# drop_last= 为在数据集无法被bach_size整除时,剩下的数据是否抛弃,True为抛弃
writer = SummaryWriter('Data_loader')
# 对象调用和多个图片展示,通过打开CIFAR10的__getitem__函数可以知道返回的值为img和target
# 注意:data_loader不能用索引的方式返回值
indx = 1
for data in data_loader:
imgs, targets = data
writer.add_images('test', imgs, global_step=indx) # 这里添加多个图片用add_images
indx += 1
# 2次epoch
for epoch in range(2):
indx = 1
for data_1 in data_loader:
imgs, targets = data_1
writer.add_images('Epoch: {}'.format(epoch), imgs, global_step=indx) # 这里添加多个图片用add_images
indx += 1
# 这里用到格式化输出,即‘this is {}’.format()。 {}中的内容添加值
writer.close()
注意:
1、DataLoader实例化的对象不能用索引进行调用,可以用for循环调用,但是datasets可以;
2、这里用到格式化输出,例如‘this is {}’.format(a), a的值将被放在{}中。
3、shuffle用于将每次epoch的值打乱,drop_last将最后不能整除的剩余数据舍弃(drop_last=True)。