TensorFlow-Keras预训练模型测试

在tf.keras.applications里有多个预训练的模型类。这些类继承了tf.keras.Model类,使用比较方便。在实例化这些类之后,程序会自动下载参数,存储于本地硬盘。目前已经有的模型包括:

  1. DenseNet121
  2. DenseNet169
  3. DenseNet201
  4. InceptionResNetV2
  5. InceptionV3
  6. MobileNet
  7. NASNetLarge
  8. NASNetMobile
  9. ResNet50
  10. VGG16
  11. VGG19
  12. Xception

这些人工神经网络模型的类都有各自对应的预处理函数。我在ILSVRC2012的验证集上对各种网络进行了测试。对图像的预处理流程如下:

  1. h=图像的高,w=图像的宽。
  2. 边长=min(h, w),偏置项=abs(h-w)//2。
  3. 如果h>w,裁剪后的图像=原图像[偏置项:偏置项+边长,边长],转到第5步;否则转到第4步。
  4. 裁剪后的图像=原图像[边长,偏置项:偏置项+边长]。
  5. 将裁剪后的图像放缩到模型所需要的大小。
  6. 用对应的preprocess_input函数进行处理。
测试结果如下图所示



大部分预训练模型的测试结果差于论文所声称的结果。只有Inception-ResNet-v2的测试结果好于原论文。这其中的原因有待考究,有可能是我对测试图像的预处理有偏差,或者预训练模型并非官方提供。

各种神经网络模型的论文错误率是取各自多种变体中最高的一个。MobileNet、NASNetLarge和Xception论文数据暂缺。


下面是测试所用的源代码 
# 导入所需模块
import tensorflow as tf 
import os
import tensorflow

import csv
from Nclasses import labels
import numpy as np 
import utils
import matplotlib.image as mpimg
import matplotlib.pyplot as plt 
import cv2
import time


# 导入预处理函数
# preprocess_input = tensorflow.keras.applications.resnet50.preprocess_input
# preprocess_input = tf.keras.applications.densenet.preprocess_input
# preprocess_input = tf.keras.applications.inception_resnet_v2.preprocess_input
# preprocess_input = tf.keras.applications.inception_v3.preprocess_input
# preprocess_input = tf.keras.applications.mobilenet.preprocess_input
preprocess_input = tf.keras.applications.xception.preprocess_input

# 导入模型类
# ResNet50 = tensorflow.keras.applications.ResNet50
# DenseNet121 = tf.keras.applications.DenseNet121
# DenseNet169 = tf.keras.applications.DenseNet169
# DenseNet201 = tf.keras.applications.DenseNet201
# InceptionResNetV2 = tf.keras.applications.InceptionResNetV2
# InceptionV3 = tf.keras.applications.InceptionV3
# MobileNet = tf.keras.applications.MobileNet
# NASNetLarge = tf.keras.applications.NASNetLarge
# VGG16 = tf.keras.applications.VGG16
# VGG19 = tf.keras.applications.VGG19
Xception = tf.keras.applications.Xception

# 导入image子模块
image = tf.keras.preprocessing.image

# 获得验证集图像地址和标签
image_paths, image_labels = utils.get_paths_labels()

# 实例化模型类
# resnet50 = ResNet50()
# model = DenseNet121()
# model = DenseNet169()
# model = DenseNet201()
# model = InceptionResNetV2()
# model = InceptionV3()
# model = MobileNet()
# model = NASNetLarge()
# model = VGG16()
# model = VGG19()
model = Xception()

# 记录top1、top5正确数目
top1_cnt = 0
top5_cnt = 0

# 用于记录测试用时
begin_time = time.clock()

# 记录总图像数目
cnt = 0

for image_path, image_label in zip(image_paths, image_labels):
	# 预处理开始
	raw_image = image.img_to_array(image.load_img(image_path))
	image_copy = np.copy(raw_image)
	shape = image_copy.shape
	h, w = shape[0], shape[1]
	if h > w: 
		h_start = (h - w) // 2
		image_copy = image_copy[h_start:h_start+w, :]
	else:
		w_start = (w - h) // 2
		image_copy = image_copy[:, w_start:w_start+h]
	image_resized = cv2.resize(image_copy, (299, 299), interpolation=cv2.INTER_CUBIC)
	processed_image = preprocess_input(image_resized).reshape((1, 299, 299, -1))

	# 预处理结束,用模型实例进行预测
	res = model.predict(processed_image)
	
	# 处理得到的结果,与标签进行对比
	# argsort()是numpy.ndarray的成员函数,从小到大排序,返回排序好的各元素对应的排序前的下标
	top5 = res.argsort().squeeze()[-1:-6:-1]
	
	if image_label in top5:
		top5_cnt += 1

	if image_label == top5[0]:
		top1_cnt += 1

	cnt += 1

	# 每10000张图片,输出一次测试耗时
	if cnt % 10000 == 0:
		end_time = time.clock()
		print('%d steps: %f' % (cnt, end_time - begin_time))
		begin_time = end_time

print('top1 accuracy:', top1_cnt / 50000)
print('top5 accuracy:', top5_cnt / 50000)
	

下面是utils.py里的代码,用于返回图像地址和标签:

import numpy as np 
import csv
from Nclasses import labels
import os



def get_paths_labels():
	# 数据集地址
	imagenet_path = r"D:\ILSVRC2012"
	# 获取训练图像文件夹名,即原始标签
	raw_labels = os.listdir(os.path.join(imagenet_path, 'img_train'))

	# 将原始标签映射到0-999
	label_dict = {}
	num_labels = np.arange(1000)
	for raw_label, num_label in zip(raw_labels, num_labels):
		label_dict[raw_label] = num_label 

	# 从csv文件中读取每张图像的地址和标签。可自行处理数据集得到这样一个csv文件。
	with open(os.path.join(imagenet_path, 'val_images.csv'), 'r') as csvfile:
		csvfile.readline()
		lines = csvfile.readlines()
		image_paths = []
		image_labels = []
		for line in lines:
			image_path, image_label = line.strip().split(',')
			image_paths.append(image_path)
			image_labels.append(label_dict[image_label])
	return image_paths, image_labels

下面是Nclasses.py的代码,将0-999的索引值映射到真实标签:

#!/usr/bin/python
#coding:utf-8
# 每个图像的真实标签,以及对应的索引值
labels = {
 0: 'tench\n Tinca tinca',
 1: 'goldfish\n Carassius auratus',
 2: 'great white shark\n white shark\n man-eater\n man-eating shark\n Carcharodon carcharias',
 3: 'tiger shark\n Galeocerdo cuvieri',
 4: 'hammerhead\n hammerhead shark',
 5: 'electric ray\n crampfish\n numbfish\n torpedo',
 6: 'stingray',
 7: 'cock',
 8: 'hen',
 9: 'ostrich\n Struthio camelus',
 10: 'brambling\n Fringilla montifringilla',
 11: 'goldfinch\n Carduelis carduelis',
 12: 'house finch\n linnet\n Carpodacus mexicanus',
 13: 'junco\n snowbird',
 14: 'indigo bunting\n indigo finch\n indigo bird\n Passerina cyanea',
 15: 'robin\n American robin\n Turdus migratorius',
 16: 'bulbul',
 17: 'jay',
 18: 'magpie',
 19: 'chickadee',
 20: 'water ouzel\n dipper',
 21: 'kite',
 22: 'bald eagle\n American eagle\n Haliaeetus leucocephalus',
 23: 'vulture',
 24: 'great grey owl\n great gray owl\n Strix nebulosa',
 25: 'European fire salamander\n Salamandra salamandra',
 26: 'common newt\n Triturus vulgaris',
 27: 'eft',
 28: 'spotted salamander\n Ambystoma maculatum',
 29: 'axolotl\n mud puppy\n Ambystoma mexicanum',
 30: 'bullfrog\n Rana catesbeiana',
 31: 'tree frog\n tree-frog',
 32: 'tailed frog\n bell toad\n ribbed toad\n tailed toad\n Ascaphus trui',
 33: 'loggerhead\n loggerhead turtle\n Caretta caretta',
 34: 'leatherback turtle\n leatherback\n leathery turtle\n Dermochelys coriacea',
 35: 'mud turtle',
 36: 'terrapin',
 37: 'box turtle\n box tortoise',
 38: 'banded gecko',
 39: 'common iguana\n iguana\n Iguana iguana',
 40: 'American chameleon\n anole\n Anolis carolinensis',
 41: 'whiptail\n whiptail lizard',
 42: 'agama',
 43: 'frilled lizard\n Chlamydosaurus kingi',
 44: 'alligator lizard',
 45: 'Gila monster\n Heloderma suspectum',
 46: 'green lizard\n Lacerta viridis',
 47: 'African chameleon\n Chamaeleo chamaeleon',
 48: 'Komodo dragon\n Komodo lizard\n dragon lizard\n giant lizard\n Varanus komodoensis',
 49: 'African crocodile\n Nile crocodile\n Crocodylus niloticus',
 50: 'American alligator\n Alligator mississipiensis',
 51: 'triceratops',
 52: 'thunder snake\n worm snake\n Carphophis amoenus',
 53: 'ringneck snake\n ring-necked snake\n ring snake',
 54: 'hognose snake\n puff adder\n sand viper',
 55: 'green snake\n grass snake',
 56: 'king snake\n kingsnake',
 57: 'garter snake\n grass snake',
 58: 'water snake',
 59: 'vine snake',
 60: 'night snake\n Hypsiglena torquata',
 61: 'boa constrictor\n Constrictor constrictor',
 62: 'rock python\n rock snake\n Python sebae',
 63: 'Indian cobra\n Naja naja',
 64: 'green mamba',
 65: 'sea snake',
 66: 'horned viper\n cerastes\n sand viper\n horned asp\n Cerastes cornutus',
 67: 'diamondback\n diamondback rattlesnake\n Crotalus adamanteus',
 68: 'sidewinder\n horned rattlesnake\n Crotalus cerastes',
 69: 'trilobite',
 70: 'harvestman\n daddy longlegs\n Phalangium opilio',
 71: 'sco
  • 2
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值