深度学习:TensorFlow2练习---CIFAR数据准备函数

当不使用Keras的`keras.datasets.cifar10.load_data()`直接下载CIFAR-10数据集后,可以通过创建`load_data`函数,自定义路径读取数据。关键步骤包括解析pickle文件,调整数据形状,并确保数据类型正确。最后,将训练集和测试集转换为适当的格式供模型使用。
摘要由CSDN通过智能技术生成

常用处理方法
在进行CIFAR数据集联系时,可以通过直接调用库函数下载cifar10数据集,下载的数据集会被储存在:

C:\Users\xxx\.keras\datasets

该路径下,执行源码解释中的函数就可以直接使用了

  (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

但是如果手动下载从官网上下载一个cifar数据集,放到指定文件夹下,该如果进行处理呢?

解决方法:

首先在项目里建一个DataSet文件夹,我的文件下载好后解压放在里面。
在这里插入图片描述

path = 'DataDet/cifar-10-batches-py'

通过按住ctrl点击cifar10.load_data(),查看该方法的源码:

import os
import numpy as np
from keras.src import backend
from keras.src.datasets.cifar import load_batch
from keras.src.utils.data_utils import get_file
from tensorflow.python.util.tf_export import keras_export

@keras_export("keras.datasets.cifar10.load_data")
def load_data():
'''
************************************************
下面的代码时下载dirname指定的cifar-10-batches-py文件,
我们已经下载好了,所以并不需要这段代码
************************************************
'''
    dirname = "cifar-10-batches-py"
    origin = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    path = get_file(
        dirname,
        origin=origin,
        untar=True,
        file_hash=(  # noqa: E501
            "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce"
        ),
    )
'''
************************************************
下面的代码是处理的核心,通过返回值可以知道该部分代码最终
得到的结果是我们想要的训练集和测试集,下面进行逐条解释
************************************************
'''
    num_train_samples = 50000 #训练集数据50000条

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8")
#设置numpy数组,uint8类型,shape为(50000,3,32,32)
    y_train = np.empty((num_train_samples,), dtype="uint8")

    for i in range(1, 6):#数据集中有6个data_batch存放了训练集的文件
        fpath = os.path.join(path, "data_batch_" + str(i))#文件名追加
        (
            x_train[(i - 1) * 10000 : i * 10000, :, :, :],
            y_train[(i - 1) * 10000 : i * 10000],
        ) = load_batch(fpath)
        #调用load_batch()函数得到对应的数据存入,通过按住Ctrl可以查看该方法,方法的作用是通过给定的路径文件Returns:A tuple `(data, labels)`.

#同理得到测试集
    fpath = os.path.join(path, "test_batch")
    x_test, y_test = load_batch(fpath)
#改变测试集的形状
    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))
#当前图像是最后一个通道
    if backend.image_data_format() == "channels_last":
        x_train = x_train.transpose(0, 2, 3, 1)#从(50000,3,32,32)改变为(50000,32,32,3)
        x_test = x_test.transpose(0, 2, 3, 1)
#设置类型
    x_test = x_test.astype(x_train.dtype)
    y_test = y_test.astype(y_train.dtype)
#返回数值
    return (x_train, y_train), (x_test, y_test)

关于load_batch()方法的代码为

import _pickle as cPickle


def load_batch(fpath, label_key="labels"):
    """Internal utility for parsing CIFAR data.

    Args:
        fpath: path the file to parse.
        label_key: key for label data in the retrieve
            dictionary.

    Returns:
        A tuple `(data, labels)`.
    """
    with open(fpath, "rb") as f:
        d = cPickle.load(f, encoding="bytes")
        # decode utf8
        d_decoded = {}
        for k, v in d.items():
            d_decoded[k.decode("utf8")] = v
        d = d_decoded
    data = d["data"]
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels

结合二者得到最终代码

import os
import numpy as np
from keras.src import backend
import _pickle as cPickle

def loadbatch(fpath, label_key="labels"):
    with open(fpath, "rb") as f:
        d = cPickle.load(f, encoding="bytes")
        # decode utf8
        d_decoded = {}
        for k, v in d.items():
            d_decoded[k.decode("utf8")] = v
        d = d_decoded
    data = d["data"]
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels

def loaddata():
    path='DataSet/cifar-10-batches-py'
    num_train_samples = 50000

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype="uint8") #定义一个NumPy数据存训练集图像数据,形状(50000,3,32,32)
    y_train = np.empty((num_train_samples,), dtype="uint8")#定义一个NumPy数据存训练集标签数据,形状(50000,3,32,32)
#训练集数据处理
    for i in range(1, 6):#遍历5个data_batch文件
        fpath = os.path.join(path, "data_batch_" + str(i))#连接路径名
        (
            x_train[(i - 1) * 10000 : i * 10000, :, :, :],
            y_train[(i - 1) * 10000 : i * 10000],
        ) = loadbatch(fpath)#得到指定路径的文件数据,x_train->data,y_train->labels
# 测试集数据处理
    fpath = os.path.join(path, "test_batch")
    x_test, y_test = loadbatch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if backend.image_data_format() == "channels_last":
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    x_test = x_test.astype(x_train.dtype)
    y_test = y_test.astype(y_train.dtype)

    return (x_train, y_train), (x_test, y_test)

测试

import matplotlib.pyplot as plt

(x_train,y_train),(x_test,y_test) = loaddata()#加载cifar数据集并分为训练集和测试集
plt.figure(figsize=(10, 4))  # 创建一个画布,画布大小为宽10、高4(单位为英寸inch)
for i, imgs in enumerate(x_train[:10]):
 # 将整个画布分成2行10列,绘制第i+1个子图。
   plt.subplot(2, 10, i+1)
   plt.imshow(imgs, cmap=plt.cm.binary)
   plt.axis('off')
for i, imgs in enumerate(x_test[:10]):
 # 将整个画布分成2行10列,绘制第i+11个子图。
   plt.subplot(2, 10, i+11)
   plt.imshow(imgs, cmap=plt.cm.binary)
   plt.axis('off')
plt.show()  #使用pycharm的需要加入这行代码才能将图像显示出来

在这里插入图片描述

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值