1.起始,因为TensorFlow优化模型使用方法,引入了tensorflow hub,让使用更简单,但导致后果就是以前的教程基本不能使用,官方例程因为hub的模型基本是使用URL模式加载,刚好上不去,这就尴尬了,需要自己慢慢摸索,对萌新及其不友好,本篇主要记录本次调试过程,方便后人绕坑。本次使用google 已经训练好的模型inception_v3,然后对最后一层进行重新训练,以满足我们需要的分类要求
2.环境要求
安装TensorFlow,CPU和GPU版本都可以,GPU比较快而已,CPU直接使用最新的即可,目前最新的GPU版本对应的cuda_10.0 和 tensorflow-gpu 1.14.0,使用cuda_10.1会出现TF无法使用cuda的问题,怀疑是Anaconda没有及时同步导致。
3.准备工作
(1)前往https://github.com/tensorflow/tensorflow ,下载对应的TensorFlow源码。
(2)安装hub, pip install tensorflow-hub ,并前往https://github.com/tensorflow/hub ,下载对应的tensorflow-hub源码
(3)准备好自己需要分类的图片,按类型划分好文件名字,我这里使用的是官方提供的数据集 flower_photos,需要的自己去下载,不用科学上网
(4).在下载下来的hub源码中找到hub-master\examples\image_retraining文件夹,运行retrain.py,开始训练。不能科学上网的会在这里被卡住,我这里提供一个野生方法,本地化模型,更改模型为本地加载,参考连接https://zhuanlan.zhihu.com/p/64069911。
下载模型文件,示例如下:
模型路径:https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3
下载模型路径:https://storage.googleapis.com/tfhub-modules/google/imagenet/inception_v3/feature_vector/3.tar.gz
下载后模型需要解压才可以正常使用,然后,运行脚本开始训练
python H:\tf_py\hub-master\examples\image_retraining\retrain.py ^
--image_dir H:\tf_py\image_retrain\flower_photos\flower_photos ^
--tfhub_module H:\tf_py\image_retrain\inception\3 ^
--saved_model_dir H:\tf_py\image_retrain\inception\4
pause
4.检测训练好的模型
因为我使用的是鲜花( 玫瑰 郁金香 向日葵 雏菊 蒲公)的分类训练,所以我去百度下了很多这种类型的图片进行测试。
不幸的是TensorFlow上面的测试例程,因为移植等问题,已经对不上这个模型的测试例程了,于是我自己码了一个心塞。
示例代码:
import tensorflow as tf
import tensorflow_hub as hub
import os
import re
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorflow.keras import layers
saved_model_dir = 'image_retrain/inception/4'
label_lookup_path = 'image_retrain/output_labels.txt'
image_path = 'image_retrain/test/'
class NodeLookup(object):
def __init__(self):
self.node_lookup = self.load(label_lookup_path)
def load(self,label_lookup_path):
proto_as_ascii_lines = tf.gfile.GFile(label_lookup_path).readlines()
node_id_to_name = {}
#一行一行读取数据
for uid,line in enumerate(proto_as_ascii_lines):
#去掉换行符
line = line.strip('\n')
node_id_to_name[uid] = line
return node_id_to_name
#传入分类编号1-1000返回分类名称
def id_to_string(self,node_id):
if node_id not in self.node_lookup:
return ''
return self.node_lookup[node_id]
with tf.Session() as sess:
# 如果不知道模型具体信息 可以使用saved_model_cli.py 查看该模型的输入 输出数据格式要求以及关键的Signature签名
#python H:\tf_py\tensorflow-master\tensorflow\python\tools\saved_model_cli.py show --dir ....\mode\ --all
meta_graph_def = tf.saved_model.loader.load(sess,["serve"], saved_model_dir)
graph = tf.get_default_graph()
oputs = sess.graph.get_tensor_by_name('final_result:0')
input_image = sess.graph.get_tensor_by_name('Placeholder:0')
#遍历目录
for root,dirs,files in os.walk(image_path):
for file in files:
#载入图片
image_data = Image.open(os.path.join(root,file)).resize([299,299])
image_data_array = np.array(image_data)/255.0
image_data_shape = np.reshape(image_data_array,[299,299,3])
#传入图片不能是tensor类型 这里使用np转化成矩阵数组格式
#原因出在tf.reshape(),因为网络训练时用placeholder定义了输入格式,所以输入不能用tensor,
#而tf.reshape()返回结果就是一个tensor了,所以输入会报错。
predictions = sess.run(oputs,{input_image:[image_data_shape]})
predictions = np.squeeze(predictions) #转化为一维数据
image_path = os.path.join(root,file)
print(image_path)
plt.imshow(image_data_shape)
plt.axis('off')
plt.show()
#排序 取概率最大的5个值 然后倒序
top_k = predictions.argsort()[-5:][::-1]
node_lookup = NodeLookup()
for node_id in top_k:
#获取分类名称
human_string = node_lookup.id_to_string(node_id)
#获取分类置信度
score = predictions[node_id]
print('%s (score = %.5f)' %(human_string,score))
print()
运行结果:
附录一下错误调试,
GPU的童鞋需要注意,训练的时候很容易出现cudnn错误,解决方法如下:
1.cudnn创建错误,环境没错的话就是显卡内存出错了,修改为按需分配
Problem:Could not create cudnn handle: CUDNN_STATUS_ALLOC_FAILED
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
2.如果不清楚的模型的输入输出,使用saved_model_cli.py 可以解决很多问题,我被这个输入数据卡了3天,才找到这个解决方案。
3.训练和测试时很容易出现莫名其妙的错误,这个时候最好重启一下python服务,或者删除缓存文件,否则你会崩溃的
4.如果能科学上网,尽量科学上网把,太折腾人了