训练好一个模型后,将其投入使用,会有在项目初始化后多次加载测试数据的需求,可以采用保存graph的思想实现
(在一个项目中需要加载多个模型同样可用)
另:这条博客接我的上一条https://blog.csdn.net/qq_34470213/article/details/104076898,是在上一个代码的基础上改写的。
1、新建文件test.py,建一个类Model_test,用来保存模型,包括一个初始化方法,用来初始化模型(项目中仅需初始化时调用一次),一个测试调用方法,用来调用模型进行测试(每次测试调用一次)。
class Model_test():
def restore(self):
self.model = Model.LeNet5(1, 5)
path = "D:/model/model/model.ckpt"
self.model.load(path)
def restore_test(self, image_path):
image = Process.process_one(image_path)
sort = self.model.test1(image)
return sort
2、在model.py的类中添加初始化函数和测试函数,这里和之前的测试函数的差别在于拆分开了加载和测试的部分,并且将graph和session保存为了类属性变量。
def load(self, model_path):
self.graph = tf.Graph()
self.sess = tf.Session(graph=self.graph)
with self.sess.graph.as_default():
self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.local_variables_initializer())
saver11 = tf.train.import_meta_graph(model_path+'.meta',
clear_devices=True)
saver11.restore(self.sess, model_path)
def test1(self, image):
x = tf.placeholder(tf.float32, [None, 64, 64, 1], name='x-input')
self.activation = self.graph.get_tensor_by_name('layer6-fc2/add:0')
image = np.array(image) / 255.0
image = np.reshape(image, (-1, 64, 64, 1))
logit = tf.arg_max(self.activation, 1)
y, label = self.sess.run((self.activation, logit), feed_dict={'x-input:0': image})
return label
3、以上两步就可以成功实现了,调用方法为:
tm = test.Model_test()
tm.restore()
……
while(True){
sort = tm.restore_test(pathname[i])
}
……