CIFAR-10 Dataset



CIFAR-10 数据集介绍

CIFAR-10 数据集由60000张 32x32 的彩色图片构成,其中每6000张图片一个类别,共10个类别。其中,训练集包含50000张图片,测试集10000张图片。

数据集被分割为5个训练 batch 和1个测试 batch。训练 batch 包含从每个类别中抽取的 5000张图片,测试 batch 包含从每个类别中随机抽取的100张图片。

CIFAR-10 数据集示例如下:

下载数据集前往官方网站 直接下载即可。


查看 CIFAR-10 数据集

官方网站 分别给出了 Python2 和 Python3 的读取数据集代码。下面是 Python3 读取数据集的代码:

def unpickle(file):
    """
    It is a function to unpickle the batch file.
    :param file: data_batch_1-5
    :return:
    """
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict


绘制 CIFAR-10 数据集内图片

在官网上下载的数据集文件夹内,有 data_batch_1, …, data_batch_55个文件存放用于训练的图片。

使用 unpickle 函数可以读取出文件内容,但是是以 numpy.ndarray 形式展现的,不是可视化的图片。本步骤的目的,是把 ndarray 形式转化为可视的图片。

因为 CIFAR-10 数据集内是彩色图片,所以是RGB形式的,输入具有3个通道。将矩阵进行一个 reshape 操作,方便绘制。

转换代码如下:

def cifar10_plot(data, meta, im_idx=0):
    # Get the image data np.ndarray
    im = data[b'data'][im_idx, :]

    im_r = im[0:1024].reshape(32, 32)
    im_g = im[1024:2048].reshape(32, 32)
    im_b = im[2048:].reshape(32, 32)

    # 1-D arrays.shape = (N, ) ----> reshape to (1, N, 1)
    # 2-D arrays.shape = (M, N) ---> reshape to (M, N, 1)
    img = np.dstack((im_r, im_g, im_b))
    # img.shape = (32, 32, 3)

    print("shape: ", img.shape)
    print("label: ", data[b'labels'][im_idx])
    print("category:", meta[b'label_names'][data[b'labels'][im_idx]])

    plt.imshow(img)
    plt.show()


完整代码

我们想通过命令行的格式,输入 0-49999 内的数字,实现对训练集图片的任意展示。完整代码如下;

# coding=utf8
"""
@author: Yantong Lai
@date: 08/21/2019
"""

import pickle
import matplotlib.pyplot as plt
import argparse
import numpy as np
import os

CIFAR10 = "../data/cifar-10-batches-py/"

parser = argparse.ArgumentParser("Plot training images in CIFAR10 dataset.")
parser.add_argument("-i", "--image", type=int, default=0,
                    help="Index of the image in ")
args = parser.parse_args()


def unpickle(file):
    """
    It is a function to unpickle the batch file.
    :param file: data_batch_1-5
    :return:
    """
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def cifar10_plot(data, meta, im_idx=0):
    # Get the image data np.ndarray
    im = data[b'data'][im_idx, :]

    im_r = im[0:1024].reshape(32, 32)
    im_g = im[1024:2048].reshape(32, 32)
    im_b = im[2048:].reshape(32, 32)

    # 1-D arrays.shape = (N, ) ----> reshape to (1, N, 1)
    # 2-D arrays.shape = (M, N) ---> reshape to (M, N, 1)
    img = np.dstack((im_r, im_g, im_b))
    # img.shape = (32, 32, 3)

    print("shape: ", img.shape)
    print("label: ", data[b'labels'][im_idx])
    print("category:", meta[b'label_names'][data[b'labels'][im_idx]])

    plt.imshow(img)
    plt.show()

def main():
    batch = (args.image // 10000) + 1
    idx = args.image - (batch-1)*10000

    data = unpickle(os.path.join(CIFAR10, "data_batch_" + str(batch)))
    meta = unpickle(os.path.join(CIFAR10, "batches.meta"))

    cifar10_plot(data, meta, im_idx=idx)


if __name__ == "__main__":
    main()

其中,batch 返回的是1-5之间的整数,用以读取 data_batch_batch文件,idx 就是该图片在data_batch_batch 文件内的索引。

在 Terminal 中运行代码:

$ python3 CIFAR10.py -i 23478

运行结果如图:


总结

本文是对 CIFAR10 数据集的简介,以及可视化代码讲解。

项目地址为 https://github.com/icmpnorequest/Pytorch-Learning/blob/master/Python3/CIFAR10.py ,可自行前往下载完整代码。

本人水平有限,文章或代码有不妥之处,请给我留言或者在 Github 上提 issue,如果喜欢我的文章或代码,请给我点赞或 star。

谢谢!

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值