# -*- coding: utf-8 -*-
import numpy as np
import struct
import matplotlib.pyplot as plt
import os
# Input IDX files and output directories (paths are relative to the
# current working directory).
filename = 'data_AI/MNIST/train-images.idx3-ubyte'
labels_filename = 'data_AI/MNIST/train-labels.idx1-ubyte'
path_trainset = "data_AI/MNIST/imgs_train"
path_testset = "data_AI/MNIST/imgs_test"

# Magic numbers that open IDX files holding unsigned-byte data.
IDX_IMAGE_MAGIC = 2051  # 0x00000803: ubyte data, 3 dimensions
IDX_LABEL_MAGIC = 2049  # 0x00000801: ubyte data, 1 dimension


def _unpack_header(buf, fmt, expected_magic):
    """Unpack a big-endian IDX header and validate its magic number.

    Args:
        buf: Bytes-like object holding the whole IDX file.
        fmt: ``struct`` format of the header; the first field is the magic.
        expected_magic: Magic number the file must start with.

    Returns:
        A ``(dims, data_offset)`` pair: the header fields after the magic,
        and the byte offset at which the payload starts.

    Raises:
        ValueError: If the buffer is shorter than the header or the magic
            number does not match.
    """
    header = struct.Struct(fmt)
    if len(buf) < header.size:
        raise ValueError(
            "IDX buffer too short for header: {} < {} bytes".format(
                len(buf), header.size))
    fields = header.unpack_from(buf, 0)
    if fields[0] != expected_magic:
        raise ValueError(
            "unexpected IDX magic number {} (expected {})".format(
                fields[0], expected_magic))
    return fields[1:], header.size


def parse_idx_images(buf):
    """Decode an IDX3 unsigned-byte image buffer into binarized images.

    Every pixel value greater than 1 is clamped to 1, so the result only
    contains the original 0/1 values plus 1 for anything brighter.

    Args:
        buf: Bytes-like object holding the whole image file.

    Returns:
        A ``uint8`` array of shape ``(num_images, num_rows, num_cols)``.

    Raises:
        ValueError: If the header is invalid or the payload is truncated.
    """
    (num_images, num_rows, num_cols), offset = _unpack_header(
        buf, '>IIII', IDX_IMAGE_MAGIC)
    # Image size comes from the header instead of a hard-coded 784.
    pixels = np.frombuffer(
        buf, dtype=np.uint8,
        count=num_images * num_rows * num_cols, offset=offset)
    # np.minimum allocates a new, writable array; the frombuffer view over
    # immutable bytes is read-only.
    return np.minimum(pixels, 1).reshape(num_images, num_rows, num_cols)


def parse_idx_labels(buf):
    """Decode an IDX1 unsigned-byte label buffer.

    Args:
        buf: Bytes-like object holding the whole label file.

    Returns:
        A writable ``uint8`` array of shape ``(num_items,)``.

    Raises:
        ValueError: If the header is invalid or the payload is truncated.
    """
    (num_items,), offset = _unpack_header(buf, '>II', IDX_LABEL_MAGIC)
    labels = np.frombuffer(buf, dtype=np.uint8, count=num_items, offset=offset)
    return labels.copy()


def _read_bytes(path):
    """Return the full contents of *path*, closing the file afterwards."""
    with open(path, 'rb') as handle:
        return handle.read()


def save_images(images, labels, out_dir, count=1):
    """Save the first *count* images as ``<index>_<label>.png`` in *out_dir*.

    Args:
        images: Array of shape ``(n, rows, cols)``.
        labels: Sequence of labels aligned with *images*.
        out_dir: Existing directory that receives the PNG files.
        count: Maximum number of images to write.
    """
    # Never index past the data that is actually available.
    total = min(count, len(images), len(labels))
    for i in range(total):
        img = images[i]
        # Debug dump of the flattened pixel vector.
        print(img.reshape(-1))
        outfile = "{}_{}.png".format(i, labels[i])
        fig = plt.figure()
        try:
            plt.imshow(img, cmap='binary')  # render the image in black and white
            plt.savefig(os.path.join(out_dir, outfile))
        finally:
            # Release the figure so repeated saves do not accumulate
            # open figures.
            plt.close(fig)
        print("save" + str(i) + "张")


def main():
    """Load the MNIST training set and export sample images to disk."""
    images = parse_idx_images(_read_bytes(filename))
    # Read the labels.
    labels = parse_idx_labels(_read_bytes(labels_filename))
    print(np.shape(labels))
    # Everything below saves the images to local files.
    for directory in (path_trainset, path_testset):
        os.makedirs(directory, exist_ok=True)
    save_images(images, labels, path_trainset, count=1)


if __name__ == "__main__":
    main()