二、如何保存MNIST数据集中train和test的图片?

如何保存MNIST数据集中train和test的图片?

介绍一种非诚神奇的图片保存方法,尤其是利用字典…format…结合来用,创建保存路径,这是一种史上很难用到的一种方法,哈哈哈哈,有点吹牛皮,不说了,言归正传。请仔细看!

// An highlighted block
from torchvision import datasets, transforms
import os

if __name__ == '__main__':
    train_data = datasets.MNIST("./data", train=True, transform=transforms.ToTensor(), download=True)  # 读取train数据
    test_data = datasets.MNIST("./data", train=False, transform=transforms.ToTensor(), download=False)  # 读取test数据
    pic_dict = {i: 0 for i in range(10)}  # 创建字典,以便将MNIST数据集0-9类按类加入,且不重名。
    # #for i, (image,label) in enumerate(test_data):#取出图片和label
    for i, data in enumerate(test_data):
        image = data[0]  # shape=6000*28*28
        label = data[1]  # shape=6000,0-910类
        img = transforms.ToPILImage()(image)  # 转化成张量,即变成tensor形式
        if os.path.exists(f'./test_img/{label}'):
            pass
        else:
            os.makedirs(f'./test_img/{label}')
        # img.save('./test_img/{}/{}.png'.format(label, pic_dict[label]))
        img.save(f'./test_img/{label}/{pic_dict[label]}.png')  # 保存路径
        pic_dict[label] += 1
    print(sum(pic_dict.values()))  # pic_dict.values()计算键值总和

上述方法局限于保存一个数据集的图片,也就是说,要么保存训练集,要么保存测试集的图片,哈哈哈,不要着急,现在有一种方法,一步登天,接着看…下面代码…
下面展示代码。

from torchvision import datasets, transforms
import os
if __name__ == '__main__':
    train_data = datasets.MNIST('./data', train=True,transform=transforms.ToTensor(),download=False)
    test_data = datasets.MNIST('./data', train=False,transform=transforms.ToTensor(),download=False)
    for i in [train_data,test_data]:#将两个数据集装在一起,嵌入一个循环。
        # 创建字典
        pic_dic = {i: 0 for i in range(10)}
        # 取出图片和标签
        if i ==train_data:
            for i, (image, label) in enumerate(train_data):
                img = transforms.ToPILImage()(image)  # 转为张量
                if not os.path.exists(f'./train_img/{label}'):
                    os.makedirs(f'./train_img/{label}')
                else:
                    pass
                img.save(f'./train_img/{label}/{pic_dic[label]}.png')  # 保存路径
                pic_dic[label] += 1
                sum_values=sum(pic_dic.values())
            print('训练集图片合计:%s张'%sum_values)  # 求键值的总和
        else:
            for i, (image, label) in enumerate(test_data):
                img = transforms.ToPILImage()(image)  # 转为张量
                if not os.path.exists(f'./test_img/{label}'):
                    os.makedirs(f'./test_img/{label}')
                else:
                    pass
                img.save(f'./test_img/{label}/{pic_dic[label]}.png')  # 保存路径
                pic_dic[label] += 1
                sum_values=sum(pic_dic.values())
            print('测试集图片合计:%s张'%sum_values)  # 求键值的总和

还有很多方法,比如PIL、OpenCV也可以实现,但是我不知道,哈哈哈哈…

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值