1. 首先，进入 MNIST 手写数字数据集的官方下载地址（http://yann.lecun.com/exdb/mnist/），下载以下 4 个文件并解压（官方提供的是 .gz 压缩包）：
分别是训练集图片、训练集标签、测试集图片、测试集标签。解压后请确认文件名与下方代码中使用的一致（例如 train-images.idx3-ubyte、t10k-labels.idx1-ubyte），否则需要相应地修改代码中的文件名。
2.将下载的文件解析成图片(.png)
代码:
# encoding: utf-8
import numpy as np
import struct
import os
import cv2
import matplotlib.pyplot as plt
# Training-set image file
train_images_idx3_ubyte_file = 'D:/minist/train-images.idx3-ubyte'
# Training-set label file
train_labels_idx1_ubyte_file = 'D:/minist/train-labels.idx1-ubyte'
# Test-set image file
test_images_idx3_ubyte_file = 'D:/minist/t10k-images.idx3-ubyte'
# Test-set label file
test_labels_idx1_ubyte_file = 'D:/minist/t10k-labels.idx1-ubyte'
# NOTE(review): parser_mnist_data() builds its own paths from data_dir and does
# not read these constants; keep them in sync manually if you rely on them.
def decode_idx3_ubyte(idx3_ubyte_file):
    """Parse an idx3-ubyte image file into a NumPy array.

    The file starts with a 16-byte header of four big-endian 32-bit integers
    (magic number, image count, rows, columns), followed by one unsigned byte
    per pixel, image after image, row-major.

    Args:
        idx3_ubyte_file: path to the idx3-ubyte file.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(num_images, num_rows, num_cols)``
        holding the raw pixel values (0-255).

    Raises:
        ValueError: if the magic number is not 2051 or the file contains fewer
            pixel bytes than the header announces.
        struct.error: if the file is shorter than the 16-byte header.
    """
    with open(idx3_ubyte_file, 'rb') as f:
        print('解析文件:', idx3_ubyte_file)
        fb_data = f.read()
    # Big-endian *signed* 32-bit ints; for MNIST every header field is far
    # below 2**31, so signed vs. unsigned makes no difference.
    fmt_header = '>iiii'
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, fb_data, 0)
    print('魔数:{},图片数:{}'.format(magic_number, num_images))
    if magic_number != 2051:
        raise ValueError(
            'not an idx3-ubyte image file (magic number {}, expected 2051): {}'.format(
                magic_number, idx3_ubyte_file))
    offset = struct.calcsize(fmt_header)
    num_pixels = num_images * num_rows * num_cols
    if len(fb_data) - offset < num_pixels:
        raise ValueError(
            'truncated idx3-ubyte file: header announces {} pixel bytes, found {}: {}'.format(
                num_pixels, len(fb_data) - offset, idx3_ubyte_file))
    # One vectorised decode instead of a struct.unpack_from call per image.
    pixels = np.frombuffer(fb_data, dtype=np.uint8, count=num_pixels, offset=offset)
    # float64 preserves the dtype the previous np.empty()-based version returned.
    return pixels.reshape((num_images, num_rows, num_cols)).astype(np.float64)
def decode_idx1_ubyte(idx1_ubyte_file):
    """Parse an idx1-ubyte label file into a list of ints.

    The file starts with an 8-byte header of two big-endian 32-bit integers
    (magic number, label count), followed by one unsigned byte per label.

    Args:
        idx1_ubyte_file: path to the idx1-ubyte file.

    Returns:
        A ``list`` of ``int`` labels, in file order.

    Raises:
        ValueError: if the magic number is not 2049 or the file contains fewer
            label bytes than the header announces.
        struct.error: if the file is shorter than the 8-byte header.
    """
    with open(idx1_ubyte_file, 'rb') as f:
        print('解析文件:', idx1_ubyte_file)
        fb_data = f.read()
    # Big-endian *signed* 32-bit ints; the values fit comfortably below 2**31.
    fmt_header = '>ii'
    magic_number, label_num = struct.unpack_from(fmt_header, fb_data, 0)
    print('魔数:{},标签数:{}'.format(magic_number, label_num))
    if magic_number != 2049:
        raise ValueError(
            'not an idx1-ubyte label file (magic number {}, expected 2049): {}'.format(
                magic_number, idx1_ubyte_file))
    offset = struct.calcsize(fmt_header)
    end = offset + label_num
    if len(fb_data) < end:
        raise ValueError(
            'truncated idx1-ubyte file: header announces {} labels, found {}: {}'.format(
                label_num, len(fb_data) - offset, idx1_ubyte_file))
    # Iterating a bytes slice yields ints directly, one per label byte.
    return list(fb_data[offset:end])
def check_folder(folder):
    """Ensure ``folder`` exists as a directory, creating it (and parents) if needed.

    The path is printed only when the directory is actually created.

    Raises:
        FileExistsError: if ``folder`` exists but is not a directory. A plain
            file cannot be turned into a directory, so the old ``os.mkdir``
            retry in that case could only fail anyway.
    """
    if os.path.isdir(folder):
        return
    if os.path.exists(folder):
        raise FileExistsError('path exists but is not a directory: {}'.format(folder))
    # makedirs also creates missing parent directories, unlike os.mkdir.
    os.makedirs(folder)
    print(folder)
def export_img(exp_dir, img_ubyte, lable_ubyte):
    """Decode an image/label idx file pair and save every image as a PNG.

    Each image is written to ``<exp_dir>/<label>/<index>.png``, where
    ``index`` is its position in the file.

    Args:
        exp_dir: output root directory (created if missing).
        img_ubyte: path to the idx3-ubyte image file.
        lable_ubyte: path to the idx1-ubyte label file. The misspelt name is
            kept for backward compatibility with keyword callers.

    Raises:
        ValueError: if the two files hold different numbers of entries, which
            usually means a train/test file mix-up.
        OSError: if OpenCV reports that an image could not be written.
    """
    check_folder(exp_dir)
    images = decode_idx3_ubyte(img_ubyte)
    labels = decode_idx1_ubyte(lable_ubyte)
    if len(images) != len(labels):
        raise ValueError('image/label count mismatch: {} images in {} but {} labels in {}'.format(
            len(images), img_ubyte, len(labels), lable_ubyte))
    ready_dirs = set()  # label folders already checked, so each is stat'ed once
    for i, (imarr, label) in enumerate(zip(images, labels)):
        img_dir = os.path.join(exp_dir, str(label))
        if img_dir not in ready_dirs:
            check_folder(img_dir)
            ready_dirs.add(img_dir)
        img_file = os.path.join(img_dir, str(i) + '.png')
        # Pixels are integral 0-255 values; convert explicitly to 8-bit
        # rather than relying on imwrite's implicit depth conversion.
        if not cv2.imwrite(img_file, imarr.astype(np.uint8)):
            raise OSError('cv2.imwrite failed for {}'.format(img_file))
def parser_mnist_data(data_dir):
    """Export both MNIST splits found in ``data_dir`` to per-label PNG folders.

    The four raw idx files are expected directly inside ``data_dir``. The
    training split is written under ``data_dir/train`` and the test split
    under ``data_dir/test``, in that order.
    """
    # (output sub-folder, image file name, label file name) for each split.
    splits = (
        ('train', 'train-images.idx3-ubyte', 'train-labels.idx1-ubyte'),
        ('test', 't10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte'),
    )
    for split_name, image_name, label_name in splits:
        export_img(
            os.path.join(data_dir, split_name),
            os.path.join(data_dir, image_name),
            os.path.join(data_dir, label_name),
        )
if __name__ == '__main__':
    # Root folder holding the four raw MNIST files; point it at your own copy.
    data_dir = 'D:/minist/'
    parser_mnist_data(data_dir=data_dir)
将代码中的路径替换成你自己存放原始数据文件和输出图片的路径即可。
解析后会在 train 和 test 目录下各生成 10 个以数字 0–9 命名的文件夹，每个文件夹里是对应数字的手写图片。
3.将图片转成csv格式,方便程序读取
代码:
import csv,os,cv2
def _iter_labelled_pixels(img_dir):
    """Yield ``(label, pixels)`` for each readable image in ``img_dir/0`` .. ``img_dir/9``.

    ``pixels`` is the grayscale image flattened row-major into a list of ints.
    Sub-directories are skipped, and so are files OpenCV cannot decode
    (e.g. ``Thumbs.db``).
    """
    # The directory holds 10 sub-folders named 0-9; the folder name is the label.
    for label in range(10):
        label_dir = os.path.join(img_dir, str(label))
        # Sorted so the CSV row order is reproducible across runs and platforms.
        for img_name in sorted(os.listdir(label_dir)):
            img_path = os.path.join(label_dir, img_name)
            # Test the full path: a bare file name would be resolved against the CWD.
            if os.path.isdir(img_path):
                continue
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:  # imread returns None instead of raising on bad files
                print('skipping unreadable file:', img_path)
                continue
            yield label, img.flatten().tolist()


def convert_img_to_csv(img_dir, csv_path=r"D:\minist\enhance.csv", num_pixels=None):
    """Write all digit images under ``img_dir`` into a single CSV file.

    The header is ``label, pixel0, pixel1, ...``. Each following row holds one
    image's label and its grayscale pixel values in row-major order.

    Args:
        img_dir: directory containing one sub-folder per digit, named 0-9.
        csv_path: output CSV path (defaults to the original hard-coded location).
        num_pixels: number of pixel columns. ``None`` infers it from the first
            image (e.g. 784 for 28x28 images). The header used to hard-code
            32*32 = 1024 columns regardless of the actual image size.

    Raises:
        ValueError: if an image's pixel count differs from ``num_pixels``.
        FileNotFoundError: if one of the digit sub-folders is missing.
    """
    def header(n):
        return ["label"] + ["pixel%d" % i for i in range(n)]

    # newline="" stops the csv module from writing blank lines between rows on Windows.
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        header_written = False
        for label, pixels in _iter_labelled_pixels(img_dir):
            if not header_written:
                if num_pixels is None:
                    num_pixels = len(pixels)
                writer.writerow(header(num_pixels))
                header_written = True
            if len(pixels) != num_pixels:
                raise ValueError("image with label {} has {} pixels, expected {}".format(
                    label, len(pixels), num_pixels))
            writer.writerow([label] + pixels)
        if not header_written:
            # No images found: still emit a header so the output is a valid CSV.
            writer.writerow(header(num_pixels or 0))
if __name__ == "__main__":
    # Dump every PNG under the training folder into one CSV file.
    convert_img_to_csv(img_dir=r"D:/minist/train")