转载请注明出处https://blog.csdn.net/EatAppleS/article/details/90172847
MNIST 数据库的下载地址：http://yann.lecun.com/exdb/mnist/
共有 4 个文件，分别是训练集的图片和标签：
train-images.idx3-ubyte
train-labels.idx1-ubyte
以及测试集的图片和标签：
t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
（官网提供的是 .gz 压缩包，需先解压，并把这 4 个文件放在运行脚本时的当前目录下，文件名要与代码中的路径一致。）
数据库及解析代码的下载地址如下:https://download.csdn.net/download/eatapples/11175660
解析代码如下:
import os
import struct

import cv2
import numpy as np
# Locations of the four MNIST IDX files, relative to the current working directory.
train_images_idx3_ubyte_file = './train-images.idx3-ubyte'
train_labels_idx1_ubyte_file = './train-labels.idx1-ubyte'
test_images_idx3_ubyte_file = './t10k-images.idx3-ubyte'
test_labels_idx1_ubyte_file = './t10k-labels.idx1-ubyte'
def decode_images(imgPath):
    """Decode an MNIST IDX3 image file into a NumPy array.

    The file begins with a big-endian header of four signed 32-bit integers
    (magic number, image count, row count, column count), followed by one
    unsigned byte per pixel with the images stored back to back.

    Args:
        imgPath: Path to an IDX3 image file such as ``train-images.idx3-ubyte``.

    Returns:
        A writable float64 array of shape ``(num_images, num_rows, num_cols)``
        holding the raw pixel bytes (0-255).

    Raises:
        ValueError: If the magic number is not 2051 (IDX3, unsigned byte),
            a header dimension is negative, or the file holds fewer pixel
            bytes than the header declares.
        struct.error: If the file is too short to contain the header.
    """
    header_fmt = '>iiii'
    # Context manager so the file handle is closed as soon as it is read.
    with open(imgPath, 'rb') as fh:
        bin_data = fh.read()
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(header_fmt, bin_data, 0)
    if magic_number != 2051:
        raise ValueError('%s: not an IDX3 ubyte image file (magic number %d, expected 2051)'
                         % (imgPath, magic_number))
    if min(num_images, num_rows, num_cols) < 0:
        raise ValueError('%s: negative dimension in IDX3 header' % imgPath)
    print('img num %d img rows %d img cols %d' % (num_images, num_rows, num_cols))
    # Interpret every pixel byte in one vectorized call; an explicit count makes
    # a truncated file raise ValueError instead of yielding short data.
    pixels = np.frombuffer(bin_data, dtype=np.uint8,
                           count=num_images * num_rows * num_cols,
                           offset=struct.calcsize(header_fmt))
    # frombuffer over bytes is read-only; astype returns a writable float64 copy.
    return pixels.reshape(num_images, num_rows, num_cols).astype(np.float64)
def decode_labels(labelPath):
    """Decode an MNIST IDX1 label file into a NumPy array.

    The file begins with a big-endian header of two signed 32-bit integers
    (magic number, item count), followed by one unsigned byte per label.

    Args:
        labelPath: Path to an IDX1 label file such as ``train-labels.idx1-ubyte``.

    Returns:
        A writable 1-D float64 array of length ``num_items`` holding the
        raw label bytes.

    Raises:
        ValueError: If the magic number is not 2049 (IDX1, unsigned byte),
            the item count is negative, or the file holds fewer label bytes
            than the header declares.
        struct.error: If the file is too short to contain the header.
    """
    header_fmt = '>ii'
    # Context manager so the file handle is closed as soon as it is read.
    with open(labelPath, 'rb') as fh:
        bin_data = fh.read()
    magic_number, num_items = struct.unpack_from(header_fmt, bin_data, 0)
    if magic_number != 2049:
        raise ValueError('%s: not an IDX1 ubyte label file (magic number %d, expected 2049)'
                         % (labelPath, magic_number))
    if num_items < 0:
        raise ValueError('%s: negative item count in IDX1 header' % labelPath)
    # One vectorized read; an explicit count rejects truncated files.
    labels = np.frombuffer(bin_data, dtype=np.uint8, count=num_items,
                           offset=struct.calcsize(header_fmt))
    # frombuffer over bytes is read-only; astype returns a writable float64 copy.
    return labels.astype(np.float64)
def _export_split(images, labels, img_dir, list_path):
    """Save one dataset split as JPEG files plus a ``<path> <label>`` list file.

    Image ``i`` is written to ``<img_dir>/<i>.jpg``, and ``list_path`` receives
    one line per image of the form ``<img_dir>/<i>.jpg <label>``.

    Args:
        images: Array of shape (N, rows, cols), as returned by decode_images.
        labels: Length-N array, as returned by decode_labels.
        img_dir: Output directory for the images; created if it is missing.
        list_path: Path of the text list file to (over)write.

    Raises:
        ValueError: If images and labels differ in length.
        OSError: If OpenCV reports that an image could not be written.
    """
    if len(images) != len(labels):
        raise ValueError('%d images but %d labels' % (len(images), len(labels)))
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(img_dir, exist_ok=True)
    with open(list_path, 'w') as f:
        for i, (image, label) in enumerate(zip(images, labels)):
            # No trailing space in the file name, so it matches the path
            # recorded in the list file.
            img_path = img_dir + '/' + str(i) + '.jpg'
            # decode_images yields whole-number pixel values 0-255, so this cast is lossless.
            if not cv2.imwrite(img_path, image.astype(np.uint8)):
                raise OSError('cv2.imwrite failed to write %s' % img_path)
            f.write(img_path + ' ' + str(int(label)) + '\n')


if __name__ == '__main__':
    train_images = decode_images(train_images_idx3_ubyte_file)
    train_labels = decode_labels(train_labels_idx1_ubyte_file)
    _export_split(train_images, train_labels, './trainData', './train.txt')
    test_images = decode_images(test_images_idx3_ubyte_file)
    test_labels = decode_labels(test_labels_idx1_ubyte_file)
    _export_split(test_images, test_labels, './testData', './test.txt')
    print('ok')