如何保存MNIST数据集中train和test的图片?
介绍一种非诚神奇的图片保存方法,尤其是利用字典…format…结合来用,创建保存路径,这是一种史上很难用到的一种方法,哈哈哈哈,有点吹牛皮,不说了,言归正传。请仔细看!
// An highlighted block
from torchvision import datasets, transforms
import os
if __name__ == '__main__':
train_data = datasets.MNIST("./data", train=True, transform=transforms.ToTensor(), download=True) # 读取train数据
test_data = datasets.MNIST("./data", train=False, transform=transforms.ToTensor(), download=False) # 读取test数据
pic_dict = {i: 0 for i in range(10)} # 创建字典,以便将MNIST数据集0-9类按类加入,且不重名。
# #for i, (image,label) in enumerate(test_data):#取出图片和label
for i, data in enumerate(test_data):
image = data[0] # shape=6000*28*28
label = data[1] # shape=6000,0-9共10类
img = transforms.ToPILImage()(image) # 转化成张量,即变成tensor形式
if os.path.exists(f'./test_img/{label}'):
pass
else:
os.makedirs(f'./test_img/{label}')
# img.save('./test_img/{}/{}.png'.format(label, pic_dict[label]))
img.save(f'./test_img/{label}/{pic_dict[label]}.png') # 保存路径
pic_dict[label] += 1
print(sum(pic_dict.values())) # pic_dict.values()计算键值总和
上述方法局限于保存一个数据集的图片,也就是说,要么保存训练集,要么保存测试集的图片,哈哈哈,不要着急,现在有一种方法,一步登天,接着看…下面代码…
下面展示代码。
from torchvision import datasets, transforms
import os
if __name__ == '__main__':
train_data = datasets.MNIST('./data', train=True,transform=transforms.ToTensor(),download=False)
test_data = datasets.MNIST('./data', train=False,transform=transforms.ToTensor(),download=False)
for i in [train_data,test_data]:#将两个数据集装在一起,嵌入一个循环。
# 创建字典
pic_dic = {i: 0 for i in range(10)}
# 取出图片和标签
if i ==train_data:
for i, (image, label) in enumerate(train_data):
img = transforms.ToPILImage()(image) # 转为张量
if not os.path.exists(f'./train_img/{label}'):
os.makedirs(f'./train_img/{label}')
else:
pass
img.save(f'./train_img/{label}/{pic_dic[label]}.png') # 保存路径
pic_dic[label] += 1
sum_values=sum(pic_dic.values())
print('训练集图片合计:%s张'%sum_values) # 求键值的总和
else:
for i, (image, label) in enumerate(test_data):
img = transforms.ToPILImage()(image) # 转为张量
if not os.path.exists(f'./test_img/{label}'):
os.makedirs(f'./test_img/{label}')
else:
pass
img.save(f'./test_img/{label}/{pic_dic[label]}.png') # 保存路径
pic_dic[label] += 1
sum_values=sum(pic_dic.values())
print('测试集图片合计:%s张'%sum_values) # 求键值的总和
还有很多方法,比如PIL、OpenCV也可以实现,但是我不知道,哈哈哈哈…