1. mnist数据集
mnist数据集是一个很经典的数据集,该数据集在这个地方可以下载到。但是下载到的数据并不是常见的图片格式,而是一种二进制格式,直接读起来很不直观,因此需要将其转换为图像格式。好在网站上给出了其数据的格式。
1.1 训练集数据
首先来看训练集的样本数据格式
可以看到,图像数据文件开头的16个字节(4个32位大端整数)依次定义了magic number、图像个数以及图像的行数和列数。其后就是全部的像素数据,每张图像占28*28=784个字节,每个字节表示一个像素。
对应的训练集标签其样本数据格式为
从上图中可以看出,除了开头的8个字节(2个32位大端整数)分别代表magic number和标签数量之外,其后全是以一个字节来表示一个标签值。
1.2 测试数据集
测试数据集和训练数据集是类似的结构,唯一的区别便是数据的数量不一样了而已。下图中是其结构
2. 转换代码
# -*- coding: utf-8 -*-
import numpy as np
import struct
from PIL import Image
import matplotlib.pyplot as plt
import os
class DataUtils(object):
    """Convert MNIST IDX binary files into NumPy arrays and image files.

    An instance is created either with ``filename`` (for :meth:`getImage`
    and :meth:`getLabel`) or with ``outpath`` (for :meth:`outImg`).
    """

    def __init__(self, filename=None, outpath=None):
        """Store the paths and prepare the ``struct`` format strings.

        Args:
            filename: Path of the binary file read by ``getImage`` /
                ``getLabel``.
            outpath: Directory that ``outImg`` writes into.
        """
        self._filename = filename
        self._outpath = outpath

        self._tag = '>'  # big-endian byte order
        self._twoBytes = 'II'  # label header: magic number, label count
        self._fourBytes = 'IIII'  # image header: magic, count, rows, cols
        self._pictureBytes = '784B'  # one 28x28 image, one byte per pixel
        self._labelByte = '1B'  # one label value

        # Byte-order-tagged variants of the formats above.
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte

        self._imgNums = 0
        self._LabelNums = 0

    def _readFile(self):
        """Return the full content of ``self._filename`` as bytes."""
        with open(self._filename, 'rb') as binFile:
            return binFile.read()

    @staticmethod
    def _payload(buf, offset, count):
        """Return ``count`` bytes of ``buf`` from ``offset`` as a uint8 view.

        Raises:
            struct.error: If ``buf`` holds fewer than ``count`` bytes after
                ``offset``.
        """
        available = len(buf) - offset
        if available < count:
            raise struct.error(
                'file truncated: header declares %d data bytes, found %d'
                % (count, available))
        return np.frombuffer(buf, dtype=np.uint8)[offset:offset + count]

    def getImage(self):
        """Parse the image file into binarized pixel feature vectors.

        The image dimensions are taken from the file header, so every
        image occupies ``rows * cols`` bytes (784 for 28x28 images).

        Returns:
            tuple: ``(images, count)``. ``images`` is an integer array of
            shape ``(count, rows * cols)`` in which a pixel byte of 0 stays
            0 and every other value becomes 1; ``count`` is the number of
            images declared in the header.

        Raises:
            struct.error: If the file is shorter than its header declares.
        """
        buf = self._readFile()
        _magic, self._imgNums, numRows, numCols = struct.unpack_from(
            self._fourBytes2, buf, 0)
        offset = struct.calcsize(self._fourBytes2)
        print('image nums: %d' % self._imgNums)

        pixelsPerImage = numRows * numCols
        raw = self._payload(buf, offset, self._imgNums * pixelsPerImage)
        # Clamp to {0, 1}; astype(int) keeps NumPy's default integer dtype,
        # which is what building the array from Python ints produced.
        images = np.minimum(raw, 1).astype(int)
        return images.reshape(self._imgNums, pixelsPerImage), self._imgNums

    def getLabel(self):
        """Parse the label file into an array of label values.

        Returns:
            numpy.ndarray: Integer array with one entry per label, in file
            order.

        Raises:
            struct.error: If the file is shorter than its header declares.
        """
        buf = self._readFile()
        _magic, self._LabelNums = struct.unpack_from(self._twoBytes2, buf, 0)
        offset = struct.calcsize(self._twoBytes2)
        return self._payload(buf, offset, self._LabelNums).astype(int)

    def outImg(self, arrX, arrY, imgNums, imgShape=(28, 28)):
        """Save feature vectors as JPEG images and append to an index file.

        Image ``i`` is written to ``<outpath>/<i>_<label>.jpg`` and the line
        ``"<i>_<label>.jpg <label>"`` is appended to ``<outpath>/img.txt``.

        Args:
            arrX: Sequence of flat pixel vectors (0 = off, non-zero = on).
            arrY: Sequence of labels, indexed like ``arrX``.
            imgNums: Number of leading entries of ``arrX`` to save.
            imgShape: ``(rows, cols)`` each vector is reshaped to.
        """
        indexPath = os.path.join(self._outpath, 'img.txt')
        # 'a+' appends, so entries from earlier runs are kept.
        with open(indexPath, 'a+') as indexFile:
            for i in range(imgNums):
                pixels = np.asarray(arrX[i]).reshape(imgShape)
                outfile = str(i) + "_" + str(arrY[i]) + ".jpg"
                print('saving file: %s' % outfile)
                indexFile.write(outfile + " " + str(arrY[i]) + '\n')
                # Passing an integer array to Image.fromarray with an
                # explicit mode '1' makes Pillow read the raw buffer as
                # packed bits and yields a corrupt picture. Build an 8-bit
                # 0/255 image first and convert that to 1-bit instead.
                gray = np.where(pixels > 0, 255, 0).astype(np.uint8)
                img = Image.fromarray(gray).convert('1')
                img.save(os.path.join(self._outpath, outfile))
                print('saving file: %s; done' % outfile)
if __name__ == '__main__':
    # Locations of the four MNIST binary files.
    trainfile_X = '../Image/train-images-idx3-ubyte'
    trainfile_y = '../Image/train-labels-idx1-ubyte'
    testfile_X = '../Image/t10k-images-idx3-ubyte'
    testfile_y = '../Image/t10k-labels-idx1-ubyte'

    # Load the MNIST data set.
    train_X, train_img_nums = DataUtils(filename=trainfile_X).getImage()
    train_y = DataUtils(filename=trainfile_y).getLabel()
    test_X, test_img_nums = DataUtils(testfile_X).getImage()
    test_y = DataUtils(testfile_y).getLabel()

    # Save the images to local files.
    path_trainset = "../Image/imgs_train"
    path_testset = "../Image/imgs_test"
    for out_dir in (path_trainset, path_testset):
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(out_dir, exist_ok=True)

    DataUtils(outpath=path_trainset).outImg(train_X, train_y, train_img_nums)
    DataUtils(outpath=path_testset).outImg(test_X, test_y, test_img_nums)