Keras内置的预定义模型
上一节我们讲过了完整的保存模型及其训练完成的参数。
Keras中使用这种方式,预置了多个著名的成熟神经网络模型。当然,这实际是Keras的功劳,并不适合算在TensorFlow 2.0头上。
当前TensorFlow 2.0-alpha版本捆绑的Keras中包含:
densenet
inception_resnet_v2
inception_v3
mobilenet
mobilenet_v2
nasnet
resnet50
vgg16
vgg19
xception
这些模型都已经使用大规模的数据训练完成,可以上手即用,实为良心佳作、码农福利。
在《从锅炉工到AI专家(8)》文中,我们演示了一个使用vgg19神经网络识别图片内容的例子。那段代码并不难,但是使用TensorFlow 1.x的API构建vgg19这种复杂的神经网络可说费劲不小。有兴趣的读者可以移步至原文再体会一下那种纠结。
而现在再做同样的事则是再简单不过了,你完全可以在你同事去茶水间倒咖啡的时间完成一个全功能的可用代码。比如跟上文功能相同的代码如下:
#!/usr/bin/env python3
import tensorflow as tf
from tensorflow import keras
# 载入vgg19模型
from tensorflow.keras.applications import vgg19
from tensorflow.keras.preprocessing import image
import numpy as np
import argparse
# 用于保存命令行参数
FLAGS = None
# 初始化vgg19模型,weights参数指的是使用ImageNet图片集训练的模型
# 每种模型第一次使用的时候都会自网络下载保存的h5文件
# vgg19的数据文件约为584M
model = vgg19.VGG19(weights='imagenet')
def main(imgPath):
# 载入命令行参数指定的图片文件, 载入时变形为224x224,这是模型规范数据要求的
img = image.load_img(imgPath, target_size=(224, 224))
# 将图片转换为(224,224,3)数组,最后的3是因为RGB三色彩图
img = image.img_to_array(img)
# 跟前面的例子一样,使用模型进行预测是批处理模式,
# 所以对于单个的图片,要扩展一维成为(1,224,224,3)这样的形式
# 相当于建立一个预测队列,但其中只有一张图片
img = np.expand_dims(img, axis=0)
# 使用模型预测(识别)
predict_class = model.predict(img)
# 获取图片识别可能性最高的3个结果
desc = vgg19.decode_predictions(predict_class, top=3)
# 我们的预测队列中只有一张图片,所以结果也只有第一个有效,显示出来
print(desc[0])
if __name__ == '__main__':
# 命令行参数处理
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--image_file', type=str, default='pics/bigcat.jpeg',
help='Pic file name')
FLAGS, unparsed = parser.parse_known_args()
main(FLAGS.image_file)
Keras库载入图片文件的代码间接引用了pillow库,所以程序执行前请先安装:pip3 install pillow。
仍然使用原文中的图片尝试识别:
$ ./pic-recognize.py -i pics/bigcat.jpeg
[('n02128385', 'leopard', 0.9778516), ('n02130308', 'cheetah', 0.008372171), ('n02128925', 'jaguar', 0.007467962)]
结果表示,图片是leopard(美洲豹)的可能性为97.79%,是cheetah(猎豹)的可能性为0.84%,是jaguar(美洲虎)的可能性为0.75%。
使用这种方式,在图片识别中,换用其他网络模型非常轻松,只需要替换程序中的三条语句,比如我们将模型换为resnet50:
模型引入,由:
from tensorflow.keras.applications import vgg19
替换为:
from tensorflow.keras.applications import resnet50
模型构建,由:
model = vgg19.VGG19(weights='imagenet')
替换为:
model = resn