TensorFlow学习笔记之CIFAR10与VGG13实战

上上篇博文TensorFlow学习笔记之Fashion MNIST数据集简单分类我们学习了Fashion MINST集的简单分类,但是Fashion MINIST数据集只保存了图片灰度的信息,不适用输入为RGB三通道的网络模型,此节我们展开CIFAR10与VGG13的实战

1 CIFAR 10数据集

1 CIFAR 10介绍

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( plane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练图片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。下面这幅图就是列举了10各类,每一类展示了随机的10张图片:
在这里插入图片描述

1.2 CIFAR10数据集的下载

1.2.1 官方下载

http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

1.2.2 keras模块直接加载

在TensorFlow中,不需要手动下载,解析和加载CIFAR10数据集,通过datasets.cifar10.load_data()函数就可以加载切割好的训练集和测试集

from keras import datasets
(x,y), (x_test, y_test) = datasets.cifar10.load_data()

TensorFlow 会自动将数据集下载在 C:\Users\用户名.keras\datasets 路径下,用户可以查
看,也可手动删除不需要的数据集缓存。

2 VGG13

2.1 选取VGG13的原因

CIFAR10图片识别的任务识别的任务不太简单。主要是因为保存的图片的分辨率仅为32x32,部分主体信息较为模糊,有时人肉眼也难以分辨。浅层的神经网络表达呢能力有限,很难训练优化到较好的性能,所以采用VGG13网络,再根据此数据集的特点****修改部分网络结构

2.2 VGG部分网络结构修改

  • List itemVGG原网络输入为224x224,现将网络输入参数调整为32x32。原网络会导致全连接输入特征维度过大,网络参数量过大
  • 3个全连接层的维度调整为[256,64,10]

2.3 VGG13模型结构

在这里插入图片描述
在这里插入图片描述

3代码

3.1 导入必要的库

import tensorflow as tf
from tensorflow.keras import layers, optimizers, datasets, Sequential
import os

3.2 处理部分

def preprocess(x, y):
	# 归一化处理[0~1]
	x= 2 * tf.cast(x, dtype = tf.float32) / 255. - 1
	y = tf.cast(y,dtype = tf.int32)
	return x,y
	#加载数据集
	(x, y), (x_test, y_test) = datasets.cifar10.load_data()
	y = tf.squeeze(y, axis=1)
	y_test = tf.squeeze(y_test, axis=1)
	print(x.shape, y.shape, x_test.shape, y_test.shape)

	test_db = tf.data.Dataset.from_tensor_slices(
  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值