打印DataLoader类的图片(CIFAR-10)

目标概述:图片已经传入DataLoader类中了,如何通过迭代DataLoader对象,将其中包含的图片打印出来并保存。

1.DataLoader对象创建过程

首先要了解DataLoader对象是如何创建的,才能理解如何将其中图片打印出来

简单概括,创建DataLoader对象步骤为:

        ①用datasets.CIFAR10加载训练集/测试集,这里对数据集进行了正则化

        ②用torch.utils.data.DataLoader对数据集封装,获得DataLoader对象train_loader


     
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=self.mean, std=self.std)
        normalized = transforms.Compose([transforms.ToTensor(), normalize])

        trainset =  datasets.CIFAR10(root='/home/c01yili/datasets/common_dataset', train=True, download=True, transform=self.normalized)
        testset =  datasets.CIFAR10(root='/home/c01yili/datasets/common_dataset', train=False, download=True, transform=self.normalized)

        
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=0)


 

2.输出DataLoader对象中的图片

①通过循环遍历train_loader

②通过squeeze函数将data的形状从(1,3,32,32)转化为(3,32,32)

③通过transpose函数将data的形状从(3,32,32)转化为(32,32,3),便于后续图像处理

④由于DataLoader对象创建过程中进行了正则化,因此这里需要对进行反正则化操作

⑤将数据类型从float转化为uint8,这一步没有的话图片输出是不正确的

⑥保存图片

    with torch.no_grad():
        for i, (data, target, ori_idx) in enumerate(train_loader):
            data = data.cpu().detach().numpy()
            data = np.squeeze(data)
            data = np.transpose(data, (1, 2, 0))  # 把channel那一维放到最后,(3,32,32)--->(32,32,3)
            # 反Normalize操作
            data = (data * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            plt.imshow(data.astype('uint8'))
            plt.axis('off')
            dir = "./overview/color/" + str(i) + ".png"
            plt.savefig(dir, dpi=1000, bbox_inches='tight', pad_inches=0)
            plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值