Python MNIST解压

python 专栏收录该内容
86 篇文章 0 订阅

解压手写数字数据库MNIST,网上找了几个教程,最后自己写了一个


目录

  1. MNIST介绍
  2. struct模块介绍
  3. 解压实现
  4. 相关实现

MNIST介绍

参考:THE MNIST DATABASE

MNIST是手写数字数据库,共有60000张训练图像和10000张测试图像

共有4个文件,保存训练图像和标签文件以及测试图像和标签文件:

train-images-idx3-ubyte.gz:  training set images (9912422 bytes) 
train-labels-idx1-ubyte.gz:  training set labels (28881 bytes) 
t10k-images-idx3-ubyte.gz:   test set images (1648877 bytes) 
t10k-labels-idx1-ubyte.gz:   test set labels (4542 bytes)

这几个文件以IDX文件格式保存,并不是常规的文件格式,需要自己写程序解析

IDX文件格式

IDX文件可用于各种数值类型的向量和多维矩阵等简单格式的保存

基本格式如下:

magic number 
size in dimension 0 
size in dimension 1 
size in dimension 2 
..... 
size in dimension N 
data

刚开始是一个魔法数字,然后接下来是第0维大小,第1维大小...以此类推,最后是字节数据

手写数字文件格式说明

MNIST手写数字文件以big-endian方式存储

标签文件格式如下:

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

4个字节是magic-number4-8个字节是标签数量,后面每个字节表示一个标签数字(0-9)

图像文件格式如下:

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel

4个字节是magic-number4-8个字节是标签数量,9-12字节是行数(高度),13-16字节是列数(宽度),后面每个字节表示一个像素值


struct模块介绍

参考:struct — Interpret bytes as packed binary data

struct模块是Python内置模块,可以执行字节数组的解析和转换,常用的函数有

struct.pack(format, v1, v2, ...)

该函数用于将v1,v2...等数值转换成字节数组,其中v1,v2...等数值的类型由format指定

struct.unpack(format, buffer)

该函数将字节数组重新解析成相应的数值,以元组形式返回

注意:需要指定正确的format格式

format

参考:Format Strings

format不仅可以指定操作数值的类型,还可以指定保存字节存储方式Byte-Order

示例一:将两个整型数字保存为字节

>>> import struct
>>> byte_arr = struct.pack('ii', 1, 2)
>>> byte_arr
b'\x01\x00\x00\x00\x02\x00\x00\x00'

其中'i'表示将Python Integer类型转换成4字节数组,参考Format Characters

有多少数字需要转换,那就在format字符串中加入多少个'i'

如果要转换多个同类型数字,也可以用如下方式

>>> struct.pack('2i', 1, 2)
b'\x01\x00\x00\x00\x02\x00\x00\x00'

示例二:将示例一的字节数组转换成数字

>>> struct.unpack('2i', byte_arr)
(1, 2)

示例三:指定字节存储模式

format字符串加上大于号(>)表示大端存储(big-endian),小于号(<)表示小端存储(little-endian

>>> struct.pack('>2i', 1, 2)
b'\x00\x00\x00\x01\x00\x00\x00\x02'
>>>
>>> struct.pack('<2i', 1, 2)
b'\x01\x00\x00\x00\x02\x00\x00\x00'

解压实现

实现如下:

# -*- coding: utf-8 -*-

import os
import struct
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

__author__ = 'zj'

# MNIST文件目录
data_dir = "C:\\datasets\\mnist"
# 结果文件名
result_file_name = 'decompress_mnist'

train = 'train'
test = 'test'
train_image_file = 'train-images.idx3-ubyte'
train_label_file = 'train-labels.idx1-ubyte'
test_image_file = 't10k-images.idx3-ubyte'
test_label_file = 't10k-labels.idx1-ubyte'


def decompress_label(label_file_path):
    """
    解压标签文件
    :param label_file_path: 标签文件路径
    :return: list
    """
    with open(label_file_path, 'rb') as f:
        magic_number = int.from_bytes(f.read(4), byteorder='big')
        image_number = int.from_bytes(f.read(4), byteorder='big')
        # print(magic_number)
        # print(image_number)
        fmt = '>B'
        byte_num = 1
        label_list = []
        for i in range(image_number):
            label_list.append(struct.unpack(fmt, f.read(byte_num))[0])
        # print(number_list)
        return label_list


def decompress_image(image_file_path):
    """
    解压图像文件
    :param image_file_path: 图像文件路径
    :return:
    """
    with open(image_file_path, 'rb') as f:
        magic_number = int.from_bytes(f.read(4), byteorder='big')
        image_number = int.from_bytes(f.read(4), byteorder='big')
        height = int.from_bytes(f.read(4), byteorder='big')
        width = int.from_bytes(f.read(4), byteorder='big')
        print('number: %d height: %d  width: %d' % (image_number, height, width))
        img_len = height * width
        fmt = '>' + str(img_len) + 'B'

        image_list = []
        for i in range(image_number):
            img = struct.unpack(fmt, f.read(img_len))
            img = np.reshape(img, (height, width))
            image_list.append(img)
        return image_list


def decompress_image_save(image_file_path, label_list, res_dir):
    """
    解压图像文件并保存
    :param image_file_path: 图像文件路径
    :param label_list: 标签列表
    :param res_dir: 结果路径
    :return:
    """
    with open(image_file_path, 'rb') as f:
        magic_number = int.from_bytes(f.read(4), byteorder='big')
        image_number = int.from_bytes(f.read(4), byteorder='big')
        height = int.from_bytes(f.read(4), byteorder='big')
        width = int.from_bytes(f.read(4), byteorder='big')
        print('number: %d height: %d  width: %d' % (image_number, height, width))
        img_len = height * width
        fmt = '>' + str(img_len) + 'B'

        nums = [0 for i in range(10)]
        for i in range(image_number):
            img = struct.unpack(fmt, f.read(img_len))
            img = np.reshape(img, (height, width))

            image_dir = os.path.join(res_dir, str(label_list[i]))
            if not os.path.exists(image_dir):
                os.mkdir(image_dir)
            image_path = os.path.join(image_dir, str(nums[label_list[i]]) + ".png")
            cv.imwrite(image_path, img)
            nums[label_list[i]] += 1


def decompress():
    """
    解压图像和标签文件
    :return:
    """
    # 创建结果文件路径
    des_dir = os.path.join(data_dir, result_file_name)
    if not os.path.exists(des_dir):
        os.mkdir(des_dir)
    des_train_dir = os.path.join(des_dir, train)
    if not os.path.exists(des_train_dir):
        os.mkdir(des_train_dir)
    des_test_dir = os.path.join(des_dir, test)
    if not os.path.exists(des_test_dir):
        os.mkdir(des_test_dir)
    print('mkdir result dir ok')

    label_list = decompress_label(os.path.join(data_dir, train_label_file))
    decompress_image_save(os.path.join(data_dir, train_image_file), label_list, des_train_dir)
    print('load image and label ok')

    label_list = decompress_label(os.path.join(data_dir, test_label_file))
    decompress_image_save(os.path.join(data_dir, test_image_file), label_list, des_test_dir)
    print('load image and label ok')


def show_image(img_list):
    # plt.figure(figsize=(10, 5))  # 设置窗口大小
    plt.figure()
    plt.suptitle('MNIST')  # 图片名称
    rows = 10
    cols = 10
    for i in range(cols):
        for j in range(rows):
            plt.subplot(rows, cols, j * cols + i + 1)
            plt.imshow(img_list[j * cols + i], cmap='gray'), plt.axis('off')
    plt.savefig('./mnit.png')
    plt.show()


if __name__ == '__main__':
    decompress()
    # image_list = decompress_image(os.path.join(data_dir, test_image_file))
    # show_image(image_list[:100])

解压好的文件资源如下:

MNIST handwritten digits


相关实现

Python读取mnist

使用Python解析MNIST数据集(IDX文件格式)

python-mnist 0.6

  • 0
    点赞
  • 2
    评论
  • 2
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

©️2021 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值