# creator : wangdiedang
# time : 2022/6/7 10:13
# filename : Read_File.py
import numpy as np
import struct as st
import matplotlib.pyplot as plt
# Default MNIST file names, in the order read_main() passes them on:
# training images, training labels, test images, test labels.
path = [
    "train-images.idx3-ubyte",
    "train-labels.idx1-ubyte",
    "t10k-images.idx3-ubyte",
    "t10k-labels.idx1-ubyte",
]
def normalize(data):
    """Binarize pixel data in place.

    Every non-zero entry of ``data`` is overwritten with 1; zero entries are
    left as they are. The array passed in is modified and nothing is
    returned.
    """
    nonzero_mask = data != 0
    data[nonzero_mask] = 1
# Read images
def read_idx3(path):
    """Load an idx3-ubyte image file and return its images binarized.

    The file starts with four big-endian 32-bit integers (magic number,
    image count, row count, column count) followed by one unsigned byte per
    pixel, image after image.

    Args:
        path: Location of the idx3 file.

    Returns:
        A float64 array of shape ``(num_images, num_rows, num_cols)`` whose
        entries are 1.0 where the stored pixel is non-zero and 0.0 elsewhere.

    Raises:
        struct.error: If the file is shorter than its header, or holds fewer
            pixel bytes than the header announces.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(path, "rb") as fh:
        raw_bin_data = fh.read()

    fmt_header = ">4i"
    _magic_number, num_images, num_rows, num_cols = st.unpack_from(
        fmt_header, raw_bin_data, 0
    )
    offset = st.calcsize(fmt_header)

    num_pixels = num_images * num_rows * num_cols
    if len(raw_bin_data) - offset < num_pixels:
        # Keep the exception type the per-image struct decoding used to raise.
        raise st.error(
            "idx3 file is truncated: expected %d pixel bytes after the header, "
            "found %d" % (num_pixels, len(raw_bin_data) - offset)
        )

    # Decode every pixel in one pass instead of unpacking image by image.
    pixels = np.frombuffer(
        raw_bin_data, dtype=np.uint8, count=num_pixels, offset=offset
    )
    images = pixels.reshape((num_images, num_rows, num_cols))

    # Binarize: any non-zero intensity becomes 1.0 (result is a fresh,
    # writable float64 array, independent of the read-only file buffer).
    return (images != 0).astype(np.float64)
# Read the matching labels
def read_idx1(path):
    """Load an idx1-ubyte label file.

    The file starts with two big-endian 32-bit integers (magic number and
    item count) followed by one unsigned byte per label.

    Args:
        path: Location of the idx1 file.

    Returns:
        A 1-D float64 array holding one label value per item.

    Raises:
        struct.error: If the file is shorter than its header, or holds fewer
            label bytes than the header announces.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(path, "rb") as fh:
        raw_bin_data = fh.read()

    fmt_header = ">2i"
    _magic_number, num_images = st.unpack_from(fmt_header, raw_bin_data, 0)
    offset = st.calcsize(fmt_header)

    if len(raw_bin_data) - offset < num_images:
        # Keep the exception type the per-label struct decoding used to raise.
        raise st.error(
            "idx1 file is truncated: expected %d label bytes after the header, "
            "found %d" % (num_images, len(raw_bin_data) - offset)
        )

    # Decode all labels at once; the float64 cast keeps the dtype callers
    # have always received and yields a writable copy.
    labels = np.frombuffer(
        raw_bin_data, dtype=np.uint8, count=num_images, offset=offset
    )
    return labels.astype(np.float64)
def read_train_and_test(train_imgs_path, train_labels_path, test_imgs_path, test_labels_path):
    """Load the training and test splits from their idx files.

    Files are read in the order: training images, training labels, test
    images, test labels. Both image arrays are then passed through
    ``normalize`` before being returned.

    Returns:
        A tuple ``(train_imgs, train_labels, test_imgs, test_labels)``.
    """
    file_pairs = (
        (train_imgs_path, train_labels_path),  # training split
        (test_imgs_path, test_labels_path),    # test split
    )

    loaded = []
    for imgs_file, labels_file in file_pairs:
        imgs = read_idx3(imgs_file)      # idx3 file: images
        labels = read_idx1(labels_file)  # idx1 file: labels
        loaded.append((imgs, labels))

    # Binarize the image arrays in place once everything has been read.
    for imgs, _labels in loaded:
        normalize(imgs)

    (train_imgs, train_labels), (test_imgs, test_labels) = loaded
    return train_imgs, train_labels, test_imgs, test_labels
# def show_img(imgs, labels):
# for i, item in enumerate(imgs):
# if i >= 15:
# break
# plt.imshow(imgs[i], cmap='gray')
# plt.pause(0.000001)
# plt.show()
# print(labels[i])
def show_img(imgs):
    """Show the first 64 images as an 8x8 mosaic in a matplotlib window.

    ``imgs`` must be a 3-D array (count, rows, cols). Images 0..63 are laid
    out row by row, and the mosaic is surrounded by a one-pixel border of
    zeros before being handed to ``plt.imshow`` / ``plt.show``.
    """
    _count, rows, cols = imgs.shape
    tiles_per_side = 8

    border_col = np.zeros((rows, 1))
    border_row = np.zeros((1, tiles_per_side * cols + 2))

    strips = [border_row]  # top border
    for strip_idx in range(tiles_per_side):
        first = strip_idx * tiles_per_side
        tiles = [imgs[first + k] for k in range(tiles_per_side)]
        # One horizontal strip: left border, eight images, right border.
        strips.append(np.concatenate([border_col] + tiles + [border_col], axis=1))
    strips.append(border_row)  # bottom border

    mosaic = np.concatenate(strips, axis=0)
    plt.imshow(mosaic)
    plt.show()
def read_main():
    """Load the full dataset using the default file names in ``path``.

    Returns whatever ``read_train_and_test`` returns for those four files.
    """
    train_imgs_file = path[0]
    train_labels_file = path[1]
    test_imgs_file = path[2]
    test_labels_file = path[3]
    return read_train_and_test(
        train_imgs_file, train_labels_file, test_imgs_file, test_labels_file
    )
# Script entry point: load the dataset when this file is run directly.
if __name__ == "__main__":
    read_main()
# Hand-written reader for the MNIST dataset (idx3 and idx1 files).
# (Source page note: "latest recommended article published 2022-11-22 16:26:05".)