# 手写读取 MNIST 数据集(idx3 和 idx1 文件) — hand-written reader for the MNIST idx3/idx1 files

# creator : wangdiedang
# time : 2022/6/7 10:13
# filename : Read_File.py

import numpy as np
import struct as st
import matplotlib.pyplot as plt

# MNIST file names, in the order read_main() hands them to
# read_train_and_test(): train images, train labels, test images, test labels.
path = [
    "train-images.idx3-ubyte",  # training images (idx3)
    "train-labels.idx1-ubyte",  # training labels (idx1)
    "t10k-images.idx3-ubyte",   # test images (idx3)
    "t10k-labels.idx1-ubyte",   # test labels (idx1)
]


def normalize(data):
    """Binarize *data* in place: every non-zero element is set to 1.

    Zeros are left untouched. The array (or array view) passed in is
    modified directly; nothing is returned.
    """
    nonzero_mask = data != 0
    data[nonzero_mask] = 1


# 读入图像
def read_idx3(path):
    """Read an IDX3 image file (e.g. ``train-images.idx3-ubyte``).

    The file starts with four big-endian 32-bit integers (magic number,
    image count, rows, cols), followed by ``count * rows * cols`` unsigned
    bytes of pixel data.

    Args:
        path: Path to the IDX3 file.

    Returns:
        A float64 array of shape ``(num_images, num_rows, num_cols)`` whose
        pixels are binarized: 0 stays 0, every non-zero pixel becomes 1.

    Raises:
        ValueError: If the magic number is not 2051 (unsigned-byte, 3-D IDX),
            or the file holds fewer pixel bytes than the header announces.
        struct.error: If the file is too short to contain the header.
    """
    fmt_header = ">4i"  # four big-endian int32 header fields
    # Read the whole file and close the handle deterministically.
    with open(path, "rb") as f:
        raw_bin_data = f.read()
    magic_number, num_images, num_rows, num_cols = st.unpack_from(fmt_header, raw_bin_data, 0)
    if magic_number != 2051:
        raise ValueError(
            f"{path!r} is not an IDX3 ubyte file "
            f"(magic number {magic_number}, expected 2051)"
        )
    # Decode every pixel in one vectorized call instead of one struct.unpack per image.
    pixels = np.frombuffer(
        raw_bin_data,
        dtype=np.uint8,
        count=num_images * num_rows * num_cols,
        offset=st.calcsize(fmt_header),
    )
    # Same binarization as normalize() (non-zero -> 1), stored as float64 like
    # the original np.empty buffer; astype also makes the result writable.
    return (pixels != 0).astype(np.float64).reshape((num_images, num_rows, num_cols))


# 读入对应标签
def read_idx1(path):
    """Read an IDX1 label file (e.g. ``train-labels.idx1-ubyte``).

    The file starts with two big-endian 32-bit integers (magic number,
    item count), followed by ``count`` unsigned bytes, one label each.

    Args:
        path: Path to the IDX1 file.

    Returns:
        A float64 array of shape ``(num_items,)`` holding the label values.

    Raises:
        ValueError: If the magic number is not 2049 (unsigned-byte, 1-D IDX),
            or the file holds fewer label bytes than the header announces.
        struct.error: If the file is too short to contain the header.
    """
    fmt_header = ">2i"  # two big-endian int32 header fields
    # Read the whole file and close the handle deterministically.
    with open(path, "rb") as f:
        raw_bin_data = f.read()
    magic_number, num_images = st.unpack_from(fmt_header, raw_bin_data, 0)
    if magic_number != 2049:
        raise ValueError(
            f"{path!r} is not an IDX1 ubyte file "
            f"(magic number {magic_number}, expected 2049)"
        )
    # Decode all labels at once instead of one struct.unpack per label.
    labels = np.frombuffer(
        raw_bin_data,
        dtype=np.uint8,
        count=num_images,
        offset=st.calcsize(fmt_header),
    )
    # Keep the float64 dtype of the original np.empty buffer so callers see no change.
    return labels.astype(np.float64)


def read_train_and_test(train_imgs_path, train_labels_path, test_imgs_path, test_labels_path):
    """Load the MNIST training and test splits.

    Args:
        train_imgs_path: IDX3 file with the training images (60,000 in MNIST).
        train_labels_path: IDX1 file with the training labels.
        test_imgs_path: IDX3 file with the test images (10,000 in MNIST).
        test_labels_path: IDX1 file with the test labels.

    Returns:
        Tuple ``(train_imgs, train_labels, test_imgs, test_labels)`` as
        returned by read_idx3 / read_idx1 (images are already binarized).
    """
    # read_idx3 binarizes every image itself, so no second normalize() pass
    # over the full image arrays is needed here.
    train_imgs = read_idx3(train_imgs_path)
    train_labels = read_idx1(train_labels_path)
    test_imgs = read_idx3(test_imgs_path)
    test_labels = read_idx1(test_labels_path)
    return train_imgs, train_labels, test_imgs, test_labels


# def show_img(imgs, labels):
#     for i, item in enumerate(imgs):
#         if i >= 15:
#             break
#         plt.imshow(imgs[i], cmap='gray')
#         plt.pause(0.000001)
#         plt.show()
#         print(labels[i])
def show_img(imgs, *, n_rows=8, n_cols=8):
    """Display the first ``n_rows * n_cols`` images of *imgs* as one tiled grid.

    Image ``i * n_cols + j`` is placed at grid cell ``(i, j)``. Tiles touch each
    other, and a 1-pixel zero border surrounds the whole grid. If *imgs* holds
    fewer images than grid cells, the remaining cells stay blank (zeros).

    Args:
        imgs: Array of shape ``(m, r, c)``, e.g. the output of read_idx3.
        n_rows: Number of tile rows in the grid (default 8).
        n_cols: Number of tile columns in the grid (default 8).
    """
    m, r, c = imgs.shape
    # Preallocate the whole canvas, including the 1-pixel border, instead of
    # growing it with repeated hstack/vstack calls.
    canvas = np.zeros((n_rows * r + 2, n_cols * c + 2))
    for k in range(min(m, n_rows * n_cols)):
        i, j = divmod(k, n_cols)
        canvas[1 + i * r:1 + (i + 1) * r, 1 + j * c:1 + (j + 1) * c] = imgs[k]
    plt.imshow(canvas)
    plt.show()


def read_main():
    """Load the full MNIST data set from the files named in the module-level ``path`` list.

    Returns:
        ``(train_imgs, train_labels, test_imgs, test_labels)`` from read_train_and_test.
    """
    train_imgs_file, train_labels_file, test_imgs_file, test_labels_file = path[0], path[1], path[2], path[3]
    return read_train_and_test(train_imgs_file, train_labels_file, test_imgs_file, test_labels_file)


if __name__ == "__main__":
    # When run as a script, just load the data set; the returned arrays are discarded.
    read_main()

  • 4
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值