Tensorflow2.x图像分类:案例——CIFAR数据集

CIFAR10 数据集由加拿大 Canadian Institute For Advanced Research 发布,它包含了飞
机、汽车、鸟、猫等共 10 大类物体的彩色图片,每个种类收集了 6000 张32 × 32大小图
片,共 6 万张图片。其中 5 万张作为训练数据集,1 万张作为测试数据集。每个种类样片
如图 所示。
关于数据集的解析可参考官网:(也可直接下载,数据集没多大,很快就下完)
https://www.cs.toronto.edu/~kriz/cifar.html
在这里插入图片描述在 TensorFlow 中,同样地,不需要手动下载、解析和加载 CIFAR10 数据集,通过
datasets.cifar10.load_data()函数就可以直接加载切割好的训练集和测试集。例如:

# 在线下载,加载 CIFAR10 数据集
(x,y), (x_test, y_test) = datasets.cifar10.load_data()
# 删除 y 的一个维度,[b,1] => [b]
y = tf.squeeze(y, axis=1)
y_test = tf.squeeze(y_test, axis=1)
# 打印训练集和测试集的形状
print(x.shape, y.shape, x_test.shape, y_test.shape)
# 构建训练集对象,随机打乱,预处理,批量化
train_db = tf.data.Dataset.from_tensor_slices((x,y))
train_db = train_db.shuffle(1000).map(preprocess).batch(128)
# 构建测试集对象,预处理,批量化
test_db = tf.data.Dataset.from_tensor_slices((x_test,y_test))
test_db = test_db.map(preprocess).batch(128)
# 从训练集中采样一个 Batch,并观察
sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,
tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))

此外,CIFAR-100数据集包含100小类,每小类包含600个图像,其中有500个训练图像和100个测试图像。100类被分组为20个大类。每个图像带有1个小类的“fine”标签和1个大类“coarse”标签。

考虑到网速慢的原因,这里直接将数据集下载,并存放在py文件当前目录。
在这里插入图片描述

实现源码如下:

# encoding: utf-8
from __future__ import print_function

class Cifar10DataReader():
    import os
    import random
    import numpy as np
    import pickle

    def __init__(self, cifar_file, one_hot=False, file_number=1):
        self.batch_index = 0  # 第i批次
        self.file_number = file_number  # 第i个文件数
        self.cifar_file = cifar_file  # 数据集所在dir
        self.one_hot = one_hot
        self.train_data = self.read_train_file()  # 一个数据文件的训练集数据,得到的是一个1000大小的list,
        self.test_data = self.read_test_data()  # 得到1000个测试集数据

    # 读取数据函数,返回dict
    def unpickle(self, file):
        with open(file, 'rb') as fo:
            try:

                dicts = self.pickle.load(fo, encoding='bytes')
            except Exception as e:
                print('load error', e)
            return dicts

    # 读取一个训练集文件,返回数据list
    def read_train_file(self, files=''):
        if files:
            files = self.os.path.join(self.cifar_file, files)
        else:
            files = self.os.path.join(self.cifar_file, 'data_batch_%d' % self.file_number)
        dict_train = self.unpickle(files)
        train_data = list(zip(dict_train[b'data'], dict_train[b'labels']))  # 将数据和对应标签打包
        self.np.random.shuffle(train_data)
        print('成功读取到训练集数据:data_batch_%d' % self.file_number)
        return train_data

    # 读取测试集数据
    def read_test_data(self):
        files = self.os.path.join(self.cifar_file, 'test_batch')
        dict_test = self.unpickle(files)
        test_data = list(zip(dict_test[b'data'], dict_test[b'labels']))  # 将数据和对应标签打包
        print('成功读取测试集数据')
        return test_data

    # 编码得到的数据,变成张量,并分别得到数据和标签
    def encodedata(self, detum):
        rdatas = list()
        rlabels = list()
        for d, l in detum:
            rdatas.append(self.np.reshape(self.np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if self.one_hot:
                hot = self.np.zeros(10)
                hot[int(l)] = 1
                rlabels.append(hot)
            else:
                rlabels.append(l)
        return rdatas, rlabels

    # 得到batch_size大小的数据和标签
    def nex_train_data(self, batch_size=100):
        assert 1000 % batch_size == 0, 'erro batch_size can not divied!'  # 判断批次大小是否能被整除

        # 获得一个batch_size的数据
        if self.batch_index < len(self.train_data) // batch_size:  # 是否超出一个文件的数据量
            detum = self.train_data[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
            datas, labels = self.encodedata(detum)
            self.batch_index += 1
        else:  # 超出了就加载下一个文件
            self.batch_index = 0
            if self.file_number == 5:
                self.file_number = 1
            else:
                self.file_number += 1
            self.read_train_file()
            return self.nex_train_data(batch_size=batch_size)
        return datas, labels

    # 随机抽取batch_size大小的训练集
    def next_test_data(self, batch_size=100):
        detum = self.random.sample(self.test_data, batch_size)  # 随机抽取
        datas, labels = self.encodedata(detum)
        return datas, labels


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    Cifar10 = Cifar10DataReader(r'./cifar-10-batches-py2', one_hot=True)
    d, l = Cifar10.nex_train_data()
    print(len(d))
    print(d)
    plt.imshow(d[0])
    plt.show()

飞机:
在这里插入图片描述

当然你也可以选择其他位置:下载下来的数据集放到~/.keras/datasets/ 目录下,然后将文件名改名为cifar-10-batches-py.tar.gz

sudo mv ~/Download/cifar-10-python.tar.gz ~/.keras/datasets/cifar-10-batches-py.tar.gz

解压一下:

tar xvfz cifar-10-batches-py.tar.gz

然后再使用

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

参考文献:
https://www.cnblogs.com/IAMzhuxiaofeng/p/9142582.html
https://www.cs.toronto.edu/~kriz/cifar.html
https://blog.csdn.net/WANG
https://blog.csdn.net/shadowl
https://blog.csdn.net/qq_26593695
https://blog.csdn.net/qq_36895331
https://blog.csdn.net/zhanzi1538/article/details/106878836/

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

源代码杀手

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值