前面的教程都只在小模型、小数据库上进行了演示,这次来真正实战一个大型数据库ImageNet。教程会分为三部分:数据增强、模型加载与训练、模型测试,最终在ResNet50上可以达到77.72%的top-1准确率,复现出了ResNet原文的结果。
完整的代码可以在我的github上找到。https://github.com/Apm5/ImageNet_Tensorflow2.0
提供ResNet-18和ResNet-50的预训练模型,以供大家做迁移使用。
链接:https://pan.baidu.com/s/1nwvkt3Ei5Hp5Pis35cBSmA
提取码:y4wo
还提供百度云链接的ImageNet原始数据,但是这份资源只能创建临时链接以供下载,有需要的还请私信联系。下面开始正文。
模型测试
本文着重于测试部分的代码实现。与训练过程类似,首先需要建立模型,然后载入权重。
model = ResNet(50)
model.build(input_shape=(None, 224, 224, 3))
model.load_weights(file_path)
然后在测试阶段需要设置training=False
来控制BN层正确计算.
prediction = model(images, training=False)
同样的,在测试阶段中也可以通过@tf.function
修饰来开启静态图模式,加速计算。
@tf.function
def test_step(model, images, labels):
prediction = model(images, training=False)
ce = cross_entropy_batch(labels, prediction)
return ce, prediction
模型中提供了三种测试脚本:
对ImageNet数据集的中心裁剪测试test.py
对ImageNet数据集的10-crop测试test_10_crop.py
对单张图像的测试test_single_image.py
下面一一进行讲解。
中心裁剪测试
在训练过程中,原始图像经过多种变换进行数据增强后才输入网络,而在测试过程中数据增强是不需要的,我们希望尽可能准确的识别图像,而不是给网络加大难度。所以测试过程的作法通常是先将图像保持长宽比缩放到短边为256大小,然后再裁剪图像中心的224*224大小的区域作为网络输入。
def center_crop(image):
height, width, _ = np.shape(image)
input_height, input_width, _ = c.input_shape
crop_x = (width - input_width) // 2
crop_y = (height - input_height) // 2
return image[crop_y: crop_y + input_height, crop_x: crop_x + input_width, :]
10-crop测试
真正在ImageNet比赛中,为了能取得更高的识别准确率,通常会对同一张图像多次裁剪不同区域,然后综合识别结果给出最终的结果,这一做法在VGG或ResNet原文中都有提到。
具体来说,仍然是将图像图像保持长宽比缩放到短边为256大小,然后裁剪左上、左下、右上、右下和中心5幅图像,然后将图像左右对称后再裁剪得到另外5幅图像,将这10幅图像作为网络的输入,然后将最后一层softmax的结果求平均作为最终的输出。
该做法能在不改变模型本身的前提下,能得到约1~2%的涨点,是一种和多模型融合类似的比赛中的的做法。
单张图像识别
在github的项目中,我还提供了识别单张图像的脚本,需要设置test_single_image.py中的model_path
和image_path
,然后直接运行即可得到结果。
python test_single_image.py
输出包括具体的识别类别和置信度。
----------------------------------------
image: ILSVRC2012_val_00000321.JPEG
classification result:computer_keyboard
confidence:0.7444
----------------------------------------