import numpy as np
import matplotlib.pyplot as plt
import os
def _read_be_u32(f):
    """Read one big-endian unsigned 32-bit integer from binary file object *f*.

    Raises:
        ValueError: if fewer than 4 bytes remain (truncated IDX header).
    """
    raw = f.read(4)
    if len(raw) != 4:
        raise ValueError("Invalid file format: truncated IDX header")
    return int.from_bytes(raw, 'big')


def read_idx_file(filename):
    """Read an unsigned-byte IDX file (MNIST layout) into a NumPy uint8 array.

    The file starts with big-endian uint32 header fields; the first is a
    magic number that selects the layout:

    * ``2051`` (image file): header ``magic, num_items, num_rows, num_cols``;
      returns an array of shape ``(num_items, num_rows, num_cols)``.
    * ``2049`` (label file): header ``magic, num_items``; returns an array of
      shape ``(num_items,)``.

    The result is built with ``np.frombuffer`` over the bytes read, so it is
    a read-only view; call ``.copy()`` if you need to modify it.

    Args:
        filename: Path of the IDX file to read.

    Returns:
        numpy.ndarray of dtype ``uint8`` with the shape described above.

    Raises:
        ValueError: if the magic number is not 2051/2049, the header is
            truncated, or the number of data bytes does not match the sizes
            declared in the header.
    """
    with open(filename, 'rb') as f:
        # Read the file header.
        magic_number = _read_be_u32(f)
        num_items = _read_be_u32(f)
        if magic_number == 2051:  # image file
            num_rows = _read_be_u32(f)
            num_cols = _read_be_u32(f)
            shape = (num_items, num_rows, num_cols)
            expected = num_items * num_rows * num_cols
        elif magic_number == 2049:  # label file
            shape = (num_items,)
            expected = num_items
        else:
            raise ValueError(f"Invalid file format (magic number {magic_number})")
        payload = f.read()
    # One byte per element: the payload must match the header exactly, so a
    # truncated or padded file is reported instead of yielding wrong data.
    if len(payload) != expected:
        raise ValueError(
            f"Invalid file format: header of {filename!r} declares "
            f"{expected} data bytes but {len(payload)} were found"
        )
    return np.frombuffer(payload, dtype=np.uint8).reshape(shape)
def save_images(data, labels, save_dir):
    """Save every image as a grey PNG under a sub-directory named by its label.

    Image ``i`` is written to ``<save_dir>/<labels[i]>/<i>.png`` with
    ``plt.imsave(..., cmap='gray')``. Sub-directories ``0``-``9`` are always
    created (even if no image uses them), plus one for every other label
    value present in *labels*.

    Args:
        data: Sequence/array of 2-D images; ``data[i]`` is image ``i``.
        labels: Sequence/array of labels, one per image.
        save_dir: Root output directory; created if it does not exist.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.
    """
    if len(data) != len(labels):
        raise ValueError(
            f"data has {len(data)} images but labels has {len(labels)} entries"
        )

    # Create the directory tree: the ten digit folders (kept for compatibility)
    # plus a folder for any other label, so imsave never hits a missing dir.
    os.makedirs(save_dir, exist_ok=True)
    class_names = {str(i) for i in range(10)}
    class_names.update(str(label) for label in labels)
    for name in class_names:
        os.makedirs(os.path.join(save_dir, name), exist_ok=True)

    # With vmin/vmax left as None, plt.imsave rescales each image to its own
    # min/max; pin the full uint8 range so stored intensities are preserved.
    if getattr(data, 'dtype', None) == np.uint8:
        vmin, vmax = 0, 255
    else:
        vmin = vmax = None

    # Save each image into the folder of its label, named by its index.
    for index, (image, label) in enumerate(zip(data, labels)):
        image_path = os.path.join(save_dir, str(label), f"{index}.png")
        plt.imsave(image_path, image, cmap='gray', vmin=vmin, vmax=vmax)
# Directory containing the four MNIST IDX files. '' means "relative to the
# current working directory" (the original behavior); set an absolute path
# here to run the script from any location.
DATA_DIR = ''
# Root directory under which 'train_images' and 'test_images' are created.
# '' keeps them relative to the current working directory.
OUTPUT_DIR = ''

# Read the training data.
train_images = read_idx_file(os.path.join(DATA_DIR, 'train-images-idx3-ubyte'))
train_labels = read_idx_file(os.path.join(DATA_DIR, 'train-labels-idx1-ubyte'))
# Read the test data.
test_images = read_idx_file(os.path.join(DATA_DIR, 't10k-images-idx3-ubyte'))
test_labels = read_idx_file(os.path.join(DATA_DIR, 't10k-labels-idx1-ubyte'))
# Save the training images, one sub-folder per label.
save_images(train_images, train_labels, os.path.join(OUTPUT_DIR, 'train_images'))
# Save the test images, one sub-folder per label.
save_images(test_images, test_labels, os.path.join(OUTPUT_DIR, 'test_images'))