用python将mnist数据集转csv格式

1.首先,进入mnist手写数字数据集官方下载地址,下载以下4个文件:

在这里插入图片描述
分别是训练集图片、训练集标签、数据集图片、数据集标签

2.将下载的文件解析成图片(.png)

代码:

# encoding: utf-8

import numpy as np
import struct
import os
import cv2
import matplotlib.pyplot as plt

# 训练集文件

# 测试集文件
test_images_idx3_ubyte_file = 'D:/minist/t10k-images.idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = 'D:/minist/t10k-labels.idx1-ubyte'


def decode_idx3_ubyte(idx3_ubyte_file):
    with open(idx3_ubyte_file, 'rb') as f:
        print('解析文件:', idx3_ubyte_file)
        fb_data = f.read()

    offset = 0
    fmt_header = '>iiii'    # 以大端法读取4个 unsinged int32
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, fb_data, offset)
    print('魔数:{},图片数:{}'.format(magic_number, num_images))
    offset += struct.calcsize(fmt_header)
    fmt_image = '>' + str(num_rows * num_cols) + 'B'

    images = np.empty((num_images, num_rows, num_cols))
    for i in range(num_images):
        im = struct.unpack_from(fmt_image, fb_data, offset)
        images[i] = np.array(im).reshape((num_rows, num_cols))
        offset += struct.calcsize(fmt_image)
    return images

def decode_idx1_ubyte(idx1_ubyte_file):
    with open(idx1_ubyte_file, 'rb') as f:
        print('解析文件:', idx1_ubyte_file)
        fb_data = f.read()

    offset = 0
    fmt_header = '>ii'  # 以大端法读取两个 unsinged int32
    magic_number, label_num = struct.unpack_from(fmt_header, fb_data, offset)
    print('魔数:{},标签数:{}'.format(magic_number, label_num))
    offset += struct.calcsize(fmt_header)
    labels = []

    fmt_label = '>B'    # 每次读取一个 byte
    for i in range(label_num):
        labels.append(struct.unpack_from(fmt_label, fb_data, offset)[0])
        offset += struct.calcsize(fmt_label)
    return labels

def check_folder(folder):
    """检查文件文件夹是否存在,不存在则创建"""
    if not os.path.exists(folder):
        os.mkdir(folder)
        print(folder)
    else:
        if not os.path.isdir(folder):
            os.mkdir(folder)


def export_img(exp_dir, img_ubyte, lable_ubyte):
    """
    生成数据集
    """
    check_folder(exp_dir)
    images = decode_idx3_ubyte(img_ubyte)
    labels = decode_idx1_ubyte(lable_ubyte)

    nums = len(labels)
    for i in range(nums):
        img_dir = os.path.join(exp_dir, str(labels[i]))
        check_folder(img_dir)
        img_file = os.path.join(img_dir, str(i)+'.png')
        imarr = images[i]
        cv2.imwrite(img_file, imarr)


def parser_mnist_data(data_dir):

    train_dir = os.path.join(data_dir, 'train')
    train_img_ubyte = os.path.join(data_dir, 'train-images.idx3-ubyte')
    train_label_ubyte = os.path.join(data_dir, 'train-labels.idx1-ubyte')
    export_img(train_dir, train_img_ubyte, train_label_ubyte)

    test_dir = os.path.join(data_dir, 'test')
    test_img_ubyte = os.path.join(data_dir, 't10k-images.idx3-ubyte')
    test_label_ubyte = os.path.join(data_dir, 't10k-labels.idx1-ubyte')
    export_img(test_dir, test_img_ubyte, test_label_ubyte)

if __name__ == '__main__':
    data_dir = 'D:/minist/'
    parser_mnist_data(data_dir)

将代码中的路径换成自己所存放和要存放的路径即可
解析后是10个文件夹,每个文件夹里是对应的手写数字图片
在这里插入图片描述
在这里插入图片描述

3.将图片转成csv格式,方便程序读取

代码:

import csv,os,cv2

def convert_img_to_csv(img_dir):
    #设置需要保存的csv路径
    with open("D:\minist\enhance.csv","w")as f:
        #设置csv文件的列名
        column_name = ["label"]
        column_name.extend(["pixel%d"%i for i in range(32*32)])
        #将列名写入到csv文件中
        writer = csv.writer(f)
        writer.writerow(column_name)
        #该目录下有9个目录,目录名从0-9
        for i in range(10):
            #获取目录的路径
            img_temp_dir = os.path.join(img_dir,str(i))
            #获取该目录下所有的文件
            img_list = os.listdir(img_temp_dir)
            #遍历所有的文件名称
            for img_name in img_list:
                #判断文件是否为目录,如果为目录则不处理
                if not os.path.isdir(img_name):
                    #获取图片的路径
                    img_path = os.path.join(img_temp_dir,img_name)
                    #因为图片是黑白的,所以以灰色读取图片
                    img = cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
                    #图片标签
                    row_data = [i]
                    #获取图片的像素
                    row_data.extend(img.flatten())
                    #将图片数据写入到csv文件中
                    writer.writerow(row_data)


if __name__ == "__main__":
    #将该目录下的图片保存为csv文件
    convert_img_to_csv(r"D:/minist/train")

  • 1
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值