目标概述:图片已经传入DataLoader类中了,如何通过迭代DataLoader对象,将其中包含的图片打印出来并保存。
1.DataLoader对象创建过程
首先要了解DataLoader对象是如何创建的,才能理解如何将其中图片打印出来
简单概括,创建DataLoader对象步骤为:
①用datasets.CIFAR10加载训练集/测试集,这里对数据集进行了正则化
②用torch.utils.data.DataLoader对数据集封装,获得DataLoader对象train_loader
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=self.mean, std=self.std)
normalized = transforms.Compose([transforms.ToTensor(), normalize])
trainset = datasets.CIFAR10(root='/home/c01yili/datasets/common_dataset', train=True, download=True, transform=self.normalized)
testset = datasets.CIFAR10(root='/home/c01yili/datasets/common_dataset', train=False, download=True, transform=self.normalized)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=0)
2.输出DataLoader对象中的图片
①通过循环遍历train_loader
②通过squeeze函数将data的形状从(1,3,32,32)转化为(3,32,32)
③通过transpose函数将data的形状从(3,32,32)转化为(32,32,3),便于后续图像处理
④由于DataLoader对象创建过程中进行了正则化,因此这里需要对进行反正则化操作
⑤将数据类型从float转化为uint8,这一步没有的话图片输出是不正确的
⑥保存图片
with torch.no_grad():
for i, (data, target, ori_idx) in enumerate(train_loader):
data = data.cpu().detach().numpy()
data = np.squeeze(data)
data = np.transpose(data, (1, 2, 0)) # 把channel那一维放到最后,(3,32,32)--->(32,32,3)
# 反Normalize操作
data = (data * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
plt.imshow(data.astype('uint8'))
plt.axis('off')
dir = "./overview/color/" + str(i) + ".png"
plt.savefig(dir, dpi=1000, bbox_inches='tight', pad_inches=0)
plt.show()