本文将会介绍如何利用Keras来搭建著名的ResNet神经网络模型,在CIFAR-10数据集进行图像分类。
数据集介绍
CIFAR-10数据集是已经标注好的图像数据集,由Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton三人收集,其访问网址为:https://www.cs.toronto.edu/~kriz/cifar.html 。
CIFAR-10数据集包含60000张尺寸为32x32的彩色图片,共分成10个分类(类别之间互相独立),每个类别一共6000张图片。该数据集划分为训练集和测试集,其中训练集5000张图片,测试集10000张图片。
该数据集分为5个训练批次和1个测试批次,每个批次一共10000张图片。测试批次包含从每个分类中随机选取的1000张图片。训练批次包含剩下的图片,但是每个训练批次的某些类别的图片会比其他类别多。
下图为从每个类别中选取的10张示例图片:
本文中选用的CIFAR-10数据集下载网址为:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,文件夹内容如下:
我们尝试着用Python程序读取里面的图片(图片可视化),Python程序代码如下:
# -*- coding: utf-8 -*-
import cv2
import pickle
# 读取文件
fpath = 'cifar-10-batches-py/data_batch_1'
with open(fpath, 'rb') as f:
d = pickle.load(f, encoding='bytes')
data = d[b'data']
labels = d[b'labels']
data = data.reshape(data.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)
# 保存第image_no张图片
strings=['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
image_no = 1000
label = strings[labels[image_no]]
image = data[image_no,:,:,:]
cv2.imwrite('%s.jpg' % label, image)
运行结果如下:
图片虽然比较模糊,但还是可以看出这是一辆车,属于truck类别。
ResNet模型
图像分类中的经典模型为CNN,但CNN随着层数的增加,显示出退化问题
,即深层次的网络反而不如稍浅层次的网络性能;这并非是过拟合导致的,因为在训练集上就显示出退化差距。而ResNet能较好地解决这个问题。
ResNet全名Residual Network,中文名为残差神经网络,曾获得2015年ImageNet的冠军。ResNet的主要思想在于残差块,Kaiming He等设计了一种skip connection(或者shortcut connections)结构,使得网络具有更强的identity mapping(恒等映射)的能力,从而拓展了网络的