# CIFAR-10: convert the dataset to images (Python)

"""
CIFAR-10 是 32X32 的彩色图片,共有10个类别,每个类别6000张图片,50000张训练图片(均分为5个batch),10000张测试图片(每个类别选1000张)
将 CIFAR-10 转为 png
"""

import os
import pickle

import numpy as np
from imageio import imwrite

# Root directory under which the dataset and the generated images live
base_dir = r'H:\DataStore'
# Location of the CIFAR-10 batch files that this script reads
data_dir = os.path.join(base_dir, 'cifar-10-batches-py')
# Output directory for the training images
train_dir = os.path.join(base_dir, 'cifar-10-train-png')
# Output directory for the test images
test_dir = os.path.join(base_dir, 'cifar-10-test-png')

# Switches for the script body: training images are not generated here,
# only the test images are
Train = False
Test = True


# Deserialization helper
def unpickle(file_path):
    """Return the object stored in the pickle file at ``file_path``.

    The file is loaded with ``encoding='bytes'``, so 8-bit strings pickled
    by Python 2 come back as ``bytes`` objects; that is why the rest of this
    script uses ``b'...'`` dictionary keys.

    NOTE: unpickling can execute arbitrary code, so only pass files that
    come from a trusted source.
    """
    with open(file_path, 'rb') as stream:
        return pickle.load(stream, encoding='bytes')


# Create the directory when it does not exist yet
def create_dir(dir_path):
    """Make sure ``dir_path`` exists as a directory, creating parents too.

    ``exist_ok=True`` makes the call a no-op for a directory that is already
    there, without a separate existence check that could race with another
    process creating the same directory in between.

    Raises:
        FileExistsError: if ``dir_path`` exists but is not a directory.
    """
    os.makedirs(dir_path, exist_ok=True)


def get_label_names():
    """Return the ``b'label_names'`` entry of ``batches.meta`` in ``data_dir``.

    The entry is indexed by label and its items are ``bytes`` (callers
    decode them before use).
    """
    meta_path = os.path.join(data_dir, 'batches.meta')
    meta = unpickle(meta_path)
    return meta[b'label_names']


def save_images(i, obj, class_num, label_names, dir_path):
    """Write image ``i`` of a batch to ``dir_path/<label name>/<count>.png``.

    Args:
        i: index of the image inside the batch.
        obj: batch dictionary with ``bytes`` keys; ``b'data'`` and
            ``b'labels'`` are read here.
        class_num: per-class counters indexed by label; the entry of the
            image's class is incremented in place.
        label_names: class names as ``bytes``, indexed by label.
        dir_path: root output directory; one sub-directory per class is
            created below it on demand.
    """
    # The keys must be written as b'' literals because obj is bytes-encoded.
    # A row holds the red, green and blue planes one after another, so view
    # it as (channel, height, width) first ...
    planes = np.reshape(obj[b'data'][i], (3, 32, 32))
    # ... then reorder to (height, width, channel) for saving as an image.
    pixels = np.transpose(planes, (1, 2, 0))

    # Class index (0-9) of this image and the matching class directory.
    label = obj[b'labels'][i]
    class_folder = os.path.join(dir_path, label_names[label].decode())
    create_dir(class_folder)

    # Bump the counter first: the new count is the file name, so numbering
    # starts at 1 within every class.
    class_num[label] += 1
    target = os.path.join(class_folder, '{}.png'.format(class_num[label]))
    imwrite(target, pixels)


if __name__ == '__main__':
    # Script entry point: export the splits enabled by the Train / Test flags.
    # The class names are shared by the training and the test split.
    _label_names = get_label_names()
    if Train:
        # Running image count per class, kept across all batches so file
        # names continue instead of restarting with every batch. Sized from
        # the label names rather than a hard-coded class count.
        train_class_num = [0] * len(_label_names)
        for batch_no in range(1, 6):
            data_batch_path = os.path.join(data_dir, 'data_batch_' + str(batch_no))
            # keys: data, labels
            train_batch_obj = unpickle(data_batch_path)
            print("{} is loading...".format(data_batch_path))
            # Take the image count from the batch itself (10000 for the
            # standard files) so a short or truncated batch cannot cause an
            # IndexError.
            for j in range(len(train_batch_obj[b'labels'])):
                save_images(j, train_batch_obj, train_class_num, _label_names, train_dir)
        print('train loaded')
    if Test:
        test_class_num = [0] * len(_label_names)
        test_data_path = os.path.join(data_dir, 'test_batch')
        test_obj = unpickle(test_data_path)
        # Same as above: iterate over however many images the batch holds.
        for j in range(len(test_obj[b'labels'])):
            save_images(j, test_obj, test_class_num, _label_names, test_dir)
        print('test loaded')

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值