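tf.keras.applications provides a number of pretrained model classes. They all inherit from tf.keras.Model, which makes them convenient to use. When one of these classes is instantiated, the program automatically downloads the weights and stores them on the local disk. The models currently available include: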
- DenseNet121
- DenseNet169
- DenseNet201
- InceptionResNetV2
- InceptionV3
- MobileNet
- NASNetLarge
- NASNetMobile
- ResNet50
- VGG16
- VGG19
- Xception
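The models listed above do not all expect the same input resolution, which matters for the resize step of any evaluation. The lookup table below is a small sketch of the default sizes as I understand the Keras defaults (299x299 for the Inception family and Xception, 331x331 for NASNetLarge, 224x224 for the rest); the names `DEFAULT_INPUT_SIZE` and `input_size` are hypothetical helpers, not part of tf.keras.

```python
# Default input resolutions (height, width) of the listed models.
# Assumption: these are the sizes used with the default ImageNet weights.
DEFAULT_INPUT_SIZE = {
    "DenseNet121": (224, 224),
    "DenseNet169": (224, 224),
    "DenseNet201": (224, 224),
    "InceptionResNetV2": (299, 299),
    "InceptionV3": (299, 299),
    "MobileNet": (224, 224),
    "NASNetLarge": (331, 331),
    "NASNetMobile": (224, 224),
    "ResNet50": (224, 224),
    "VGG16": (224, 224),
    "VGG19": (224, 224),
    "Xception": (299, 299),
}


def input_size(model_name):
    """Return the (height, width) a model expects by default."""
    return DEFAULT_INPUT_SIZE[model_name]


print(input_size("Xception"))
```

Looking the size up by model name avoids hardcoding one resolution when switching between networks.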
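Each of these neural network classes has its own corresponding preprocessing function. I tested the various networks on the ILSVRC2012 validation set. The images were preprocessed as follows: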
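- Let h be the image height and w the image width.
- Let side = min(h, w) and offset = abs(h - w) // 2.
- If h > w, cropped image = original[offset:offset+side, :]; otherwise cropped image = original[:, offset:offset+side].
- Resize the cropped image to the input size the model requires.
- Apply the corresponding preprocess_input function.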
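The cropping part of the procedure above can be sketched with plain NumPy slicing. This is a minimal illustration, assuming the image is an array of shape (h, w, channels); `center_crop_square` is a hypothetical helper name, and the resize and preprocess_input steps are left out.

```python
import numpy as np


def center_crop_square(img):
    """Crop the central square out of an (h, w, ...) image array."""
    h, w = img.shape[:2]
    side = min(h, w)
    offset = abs(h - w) // 2
    if h > w:
        # Taller than wide: trim rows, keep every column
        return img[offset:offset + side, :]
    # Wider than tall (or already square): trim columns, keep every row
    return img[:, offset:offset + side]


tall = np.zeros((500, 300, 3), dtype=np.float32)
wide = np.zeros((300, 501, 3), dtype=np.float32)
print(center_crop_square(tall).shape)
print(center_crop_square(wide).shape)
```

Both calls yield a 300x300 crop; with an odd size difference the extra pixel is trimmed from the far edge.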
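Most of the pretrained models performed worse in this test than the results claimed in their papers. Only Inception-ResNet-v2 did better than its original paper. The reason remains to be investigated: my preprocessing of the test images may deviate from the original, or the pretrained weights may not be the officially released ones.
For each model, the paper error rate quoted is the highest one among its reported variants. Paper figures for MobileNet, NASNetLarge and Xception are not yet available.
The source code used for the test is given below: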
# Import the required modules
import time

import cv2
import numpy as np
import tensorflow as tf

import utils
# Select the preprocessing function (uncomment the one that matches the model)
# preprocess_input = tf.keras.applications.resnet50.preprocess_input
# preprocess_input = tf.keras.applications.densenet.preprocess_input
# preprocess_input = tf.keras.applications.inception_resnet_v2.preprocess_input
# preprocess_input = tf.keras.applications.inception_v3.preprocess_input
# preprocess_input = tf.keras.applications.mobilenet.preprocess_input
preprocess_input = tf.keras.applications.xception.preprocess_input
# Select the model class (uncomment the one to test)
# ResNet50 = tf.keras.applications.ResNet50
# DenseNet121 = tf.keras.applications.DenseNet121
# DenseNet169 = tf.keras.applications.DenseNet169
# DenseNet201 = tf.keras.applications.DenseNet201
# InceptionResNetV2 = tf.keras.applications.InceptionResNetV2
# InceptionV3 = tf.keras.applications.InceptionV3
# MobileNet = tf.keras.applications.MobileNet
# NASNetLarge = tf.keras.applications.NASNetLarge
# VGG16 = tf.keras.applications.VGG16
# VGG19 = tf.keras.applications.VGG19
Xception = tf.keras.applications.Xception
# Alias the image submodule
image = tf.keras.preprocessing.image
# Get the validation image paths and labels
image_paths, image_labels = utils.get_paths_labels()
# Instantiate the model class
# model = ResNet50()
# model = DenseNet121()
# model = DenseNet169()
# model = DenseNet201()
# model = InceptionResNetV2()
# model = InceptionV3()
# model = MobileNet()
# model = NASNetLarge()
# model = VGG16()
# model = VGG19()
model = Xception()
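# Count the top-1 and top-5 hits
top1_cnt = 0
top5_cnt = 0
# Start of the current timing window (time.clock() was removed in Python 3.8)
begin_time = time.perf_counter()
# Number of images processed so far
cnt = 0
for image_path, image_label in zip(image_paths, image_labels):
    # Preprocessing starts: load the image as an (h, w, 3) float array
    raw_image = image.img_to_array(image.load_img(image_path))
    image_copy = np.copy(raw_image)
    shape = image_copy.shape
    h, w = shape[0], shape[1]
    # Center-crop to a square along the longer side
    if h > w:
        h_start = (h - w) // 2
        image_copy = image_copy[h_start:h_start+w, :]
    else:
        w_start = (w - h) // 2
        image_copy = image_copy[:, w_start:w_start+h]
    # 299x299 is the Xception input size; change it to match the chosen model
    image_resized = cv2.resize(image_copy, (299, 299), interpolation=cv2.INTER_CUBIC)
    processed_image = preprocess_input(image_resized).reshape((1, 299, 299, -1))
    # Preprocessing done; run the model on the image
    res = model.predict(processed_image)
    # Compare the prediction with the label.
    # argsort() is a numpy.ndarray method that sorts in ascending order and
    # returns the original indices, so the reversed last five entries are the
    # top-5 classes, best first.
    top5 = res.argsort().squeeze()[-1:-6:-1]
    if image_label in top5:
        top5_cnt += 1
    if image_label == top5[0]:
        top1_cnt += 1
    cnt += 1
    # Print the elapsed time once every 10000 images
    if cnt % 10000 == 0:
        end_time = time.perf_counter()
        print('%d steps: %f' % (cnt, end_time - begin_time))
        begin_time = end_time
# Divide by the number of images actually processed (50000 for the full set)
print('top1 accuracy:', top1_cnt / max(cnt, 1))
print('top5 accuracy:', top5_cnt / max(cnt, 1))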
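The top-1/top-5 bookkeeping in the script above works one image at a time. As a sketch of the same argsort idea applied to a whole batch of scores at once, the snippet below counts hits with NumPy; `topk_hits` is a hypothetical helper and the score matrix is made-up toy data, not model output.

```python
import numpy as np


def topk_hits(scores, labels, k=5):
    """Count top-1 and top-k hits for a (n_samples, n_classes) score matrix."""
    # argsort is ascending, so reverse each row and keep the first k indices
    topk = np.argsort(scores, axis=1)[:, ::-1][:, :k]
    top1_count = int(np.sum(topk[:, 0] == labels))
    topk_count = int(np.sum(np.any(topk == labels[:, None], axis=1)))
    return top1_count, topk_count


# Toy scores for 3 samples and 3 classes
scores = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.2, 0.3],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 2, 0])
top1, top2 = topk_hits(scores, labels, k=2)
print(top1, top2)
```

Here the first sample is a top-1 hit, the second is only a top-2 hit, and the third misses both, so the counts are 1 and 2.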
Below is the code in utils.py, which returns the image paths and labels:
import os


def get_paths_labels():
    # Dataset location
    imagenet_path = r"D:\ILSVRC2012"
    # The training image folder names are the raw labels.
    # Sort them so the mapping does not depend on the order os.listdir() returns.
    raw_labels = sorted(os.listdir(os.path.join(imagenet_path, 'img_train')))
    # Map the raw labels to 0-999
    label_dict = {raw_label: index for index, raw_label in enumerate(raw_labels)}
    # Read each image's path and label from a csv file.
    # You can build such a csv file from the dataset yourself.
    image_paths = []
    image_labels = []
    with open(os.path.join(imagenet_path, 'val_images.csv'), 'r') as csvfile:
        csvfile.readline()  # skip the first line (assumed to be a header)
        for line in csvfile:
            image_path, image_label = line.strip().split(',')
            image_paths.append(image_path)
            image_labels.append(label_dict[image_label])
    return image_paths, image_labels
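To show the label mapping without needing the dataset on disk, here is a self-contained sketch that builds the folder-name-to-index dictionary and parses an in-memory CSV. The helper names, the `path,label` header, and the sample file names are assumptions for illustration; only the idea of sorting the folder names and indexing them comes from the code above.

```python
import csv
import io


def build_label_dict(folder_names):
    """Map sorted class folder names to consecutive integer indices."""
    return {name: index for index, name in enumerate(sorted(folder_names))}


def read_paths_labels(csv_text, label_dict):
    """Parse 'path,label' rows (after one header row) into two lists."""
    reader = csv.reader(io.StringIO(csv_text))
    next(reader)  # skip the header row
    paths, labels = [], []
    for row in reader:
        if not row:  # ignore blank lines
            continue
        path, raw_label = row
        paths.append(path)
        labels.append(label_dict[raw_label])
    return paths, labels


# Three sample class folders, deliberately out of order
folders = ["n01443537", "n01440764", "n01484850"]
label_dict = build_label_dict(folders)
csv_text = "path,label\nval/a.JPEG,n01484850\nval/b.JPEG,n01440764\n"
paths, labels = read_paths_labels(csv_text, label_dict)
print(labels)
```

Sorting before enumerating makes the indices reproducible across machines, since directory listing order is not guaranteed.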
Below is the code in Nclasses.py, which maps the indices 0-999 to the real labels:
#!/usr/bin/python
#coding:utf-8
# The real label of each class, keyed by its index
labels = {
0: 'tench\n Tinca tinca',
1: 'goldfish\n Carassius auratus',
2: 'great white shark\n white shark\n man-eater\n man-eating shark\n Carcharodon carcharias',
3: 'tiger shark\n Galeocerdo cuvieri',
4: 'hammerhead\n hammerhead shark',
5: 'electric ray\n crampfish\n numbfish\n torpedo',
6: 'stingray',
7: 'cock',
8: 'hen',
9: 'ostrich\n Struthio camelus',
10: 'brambling\n Fringilla montifringilla',
11: 'goldfinch\n Carduelis carduelis',
12: 'house finch\n linnet\n Carpodacus mexicanus',
13: 'junco\n snowbird',
14: 'indigo bunting\n indigo finch\n indigo bird\n Passerina cyanea',
15: 'robin\n American robin\n Turdus migratorius',
16: 'bulbul',
17: 'jay',
18: 'magpie',
19: 'chickadee',
20: 'water ouzel\n dipper',
21: 'kite',
22: 'bald eagle\n American eagle\n Haliaeetus leucocephalus',
23: 'vulture',
24: 'great grey owl\n great gray owl\n Strix nebulosa',
25: 'European fire salamander\n Salamandra salamandra',
26: 'common newt\n Triturus vulgaris',
27: 'eft',
28: 'spotted salamander\n Ambystoma maculatum',
29: 'axolotl\n mud puppy\n Ambystoma mexicanum',
30: 'bullfrog\n Rana catesbeiana',
31: 'tree frog\n tree-frog',
32: 'tailed frog\n bell toad\n ribbed toad\n tailed toad\n Ascaphus trui',
33: 'loggerhead\n loggerhead turtle\n Caretta caretta',
34: 'leatherback turtle\n leatherback\n leathery turtle\n Dermochelys coriacea',
35: 'mud turtle',
36: 'terrapin',
37: 'box turtle\n box tortoise',
38: 'banded gecko',
39: 'common iguana\n iguana\n Iguana iguana',
40: 'American chameleon\n anole\n Anolis carolinensis',
41: 'whiptail\n whiptail lizard',
42: 'agama',
43: 'frilled lizard\n Chlamydosaurus kingi',
44: 'alligator lizard',
45: 'Gila monster\n Heloderma suspectum',
46: 'green lizard\n Lacerta viridis',
47: 'African chameleon\n Chamaeleo chamaeleon',
48: 'Komodo dragon\n Komodo lizard\n dragon lizard\n giant lizard\n Varanus komodoensis',
49: 'African crocodile\n Nile crocodile\n Crocodylus niloticus',
50: 'American alligator\n Alligator mississipiensis',
51: 'triceratops',
52: 'thunder snake\n worm snake\n Carphophis amoenus',
53: 'ringneck snake\n ring-necked snake\n ring snake',
54: 'hognose snake\n puff adder\n sand viper',
55: 'green snake\n grass snake',
56: 'king snake\n kingsnake',
57: 'garter snake\n grass snake',
58: 'water snake',
59: 'vine snake',
60: 'night snake\n Hypsiglena torquata',
61: 'boa constrictor\n Constrictor constrictor',
62: 'rock python\n rock snake\n Python sebae',
63: 'Indian cobra\n Naja naja',
64: 'green mamba',
65: 'sea snake',
66: 'horned viper\n cerastes\n sand viper\n horned asp\n Cerastes cornutus',
67: 'diamondback\n diamondback rattlesnake\n Crotalus adamanteus',
68: 'sidewinder\n horned rattlesnake\n Crotalus cerastes',
69: 'trilobite',
70: 'harvestman\n daddy longlegs\n Phalangium opilio',
71: 'scorpion',
72: 'black and gold garden spider\n Argiope aurantia',