# This Python script reads the MNIST handwritten-digit dataset and displays an image; it can optionally save images to the local disk.
import torch
import matplotlib.pyplot as plt
import struct
import os
def save_mnist_pic(image_dir: str, label_dir: str, save_img_count: int, store_dir: str = '.\\img\\'):
    """Load an MNIST IDX image/label file pair and optionally save images.

    The image file must start with the IDX3 magic number ``0x0803`` followed
    by the image count, row count and column count (big-endian uint32 each).
    The label file must start with the IDX1 magic number ``0x0801`` followed
    by the label count. Pixels and labels are one unsigned byte each.

    Args:
        image_dir: Path of the IDX3 image file.
        label_dir: Path of the IDX1 label file.
        save_img_count: Number of leading images to write as ``<index>.jpg``
            into ``store_dir``. Values larger than the number of images are
            clamped; ``0`` (or a negative value) saves nothing.
        store_dir: Output directory for saved images. It is created only
            when at least one image is actually saved.

    Returns:
        ``(pic, label)`` where ``pic`` is a tensor of shape
        ``(image_count, rows, cols)`` in torch's default floating dtype and
        ``label`` is a list of ints. ``(None, None)`` is returned when either
        file is not a valid MNIST file, the image data is truncated, or the
        label count differs from the image count.
    """
    image_header_size = 4 * 4
    label_header_size = 4 * 2

    with open(image_dir, 'rb') as fp:
        header = fp.read(image_header_size)
        # A short header cannot be unpacked and cannot be a valid IDX3 file.
        if len(header) < image_header_size:
            print("not a mnist file!")
            return None, None
        magic, image_count, rows, cols = struct.unpack('>IIII', header)
        if magic != 0x0803:
            print("not a mnist file!")
            return None, None
        print("Image_Count:%d Solution:%d*%d" % (image_count, rows, cols))
        raw_image = fp.read()

    # Use the dimensions declared in the header instead of assuming 28x28.
    pixel_total = image_count * rows * cols
    if len(raw_image) < pixel_total:
        print("image data is truncated!")
        return None, None
    if pixel_total == 0:
        # torch.frombuffer rejects empty buffers, so build the empty tensor directly.
        pic = torch.zeros(image_count, rows, cols)
    else:
        # Decode every pixel in one pass. bytearray gives frombuffer a
        # writable buffer; the dtype conversion then copies the data so the
        # result does not alias that temporary buffer.
        pic = (
            torch.frombuffer(bytearray(raw_image), dtype=torch.uint8, count=pixel_total)
            .reshape(image_count, rows, cols)
            .to(torch.get_default_dtype())
        )

    with open(label_dir, 'rb') as fp:
        header = fp.read(label_header_size)
        if len(header) < label_header_size:
            print("not a label file!")
            return None, None
        magic, label_count = struct.unpack('>II', header)
        if magic != 0x0801:
            print("not a label file!")
            # Return a pair like every other failure path so that callers
            # unpacking two values do not crash.
            return None, None
        print("Label_Count:%d" % (label_count))
        if label_count != image_count:
            print("the label count is not same as the image count!")
            return None, None
        label = list(fp.read())

    # Never try to save more images than were loaded.
    save_count = min(save_img_count, image_count)
    if save_count > 0:
        os.makedirs(store_dir, exist_ok=True)
        for index in range(save_count):
            # os.path.join also works when store_dir has no trailing separator.
            plt.imsave(os.path.join(store_dir, str(index) + '.jpg'), pic[index], cmap='Greys')
    return pic, label
# Demo: load the MNIST training set without saving any images (count = 0),
# then display the first digit with its label as the figure title.
result = save_mnist_pic("./train-images.idx3-ubyte", "./train-labels.idx1-ubyte", 0)
# A failed load is reported as None or (None, None); normalise both so the
# unpacking below never raises.
pic, label = result if result is not None else (None, None)
# Only plot when there is at least one image and one label to show.
if pic is not None and label and len(pic) > 0:
    plt.title(str(label[0]), color='r')
    plt.axis('off')
    plt.imshow(pic[0], cmap='Greys')
    plt.show()