tensorflow的API提供以下两种方式来存储和加载模型
1.生成检查点文件(checkpoint file),扩展名一般为.ckpr,通过tf.train.Saver来保存,包含权重和其他在程序中定义的变量,但不包含图结构。
2.生成图结构,扩展名一般为.pb,使用tf.train.write_graph()保存,只包含图结构,不包含权重
所以一般都是两个一起结合使用
例子:
# coding=utf-8
'''
Created on 2019年4月5日
@author: admin
'''
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
mnist = input_data.read_data_sets('mnist', one_hot=True)
trX,trY,teX,teY = mnist.train.images,mnist.train.labels,mnist.train.images,mnist.test.labels
x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
w1 = tf.Variable(tf.random_normal([784,625]))
w2 = tf.Variable(tf.random_normal([625,625]))
o = tf.Variable(tf.random_normal([625,10]))
keep_prob1 = tf.placeholder("float")
keep_prob2 = tf.placeholder("float")
def model(x,w1,w2,o):
x = tf.nn.relu(tf.matmul(x,w1))
x = tf.nn.dropout(x,keep_prob1)
x = tf.nn.relu(tf.matmul(x,w2))
x = tf.nn.dropout(x,keep_prob2)
return tf.matmul(x,o)
pred = model(x,w1,w2,o)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y),name="cost")
train_op = tf.train.RMSPropOptimizer(0.001,0.9,name="train_op").minimize(cost)
#虽然前面命名为cost,但是最好还是打印出来看一下名称
print(cost)
# Tensor("cost:0", shape=(), dtype=float32)
print(train_op)
predict_op = tf.argmax(pred,1)
# 保存模型超参数
ckpt_dir = "modle_ckpt"
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
model_saver = tf.train.Saver()
# 保存模型图
graph_dir = "modle_graph"
if not os.path.exists(graph_dir):
os.makedirs(graph_dir)
with tf.Session() as sess:
tf.initialize_all_variables().run()
#保存图模型
tf.train.write_graph(sess.graph_def, graph_dir, 'graph.pbtxt',as_text=False)
#加载图模型
with tf.gfile.FastGFile("modle_graph/graph.pbtxt","rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
eles = tf.import_graph_def(graph_def)
# print(eles)
cost1 = sess.graph.get_tensor_by_name("cost:0")
print(cost1)
train_op1 = sess.graph.get_operation_by_name("train_op")
print(train_op1)
#加载已经保存好的模型超参数
# ckpt = tf.train.get_checkpoint_state(ckpt_dir)
# print(ckpt.model_checkpoint_path)
# model_saver.restore(sess,ckpt.model_checkpoint_path)
for step in range(1):
for start,end in zip(range(0,len(trX),128),range(128,len(trX)+1,128)):
loss,_ = sess.run([cost1,train_op1],feed_dict={x:trX[start:end],y:trY[start:end],keep_prob1:0.8,keep_prob2:0.8})
print(loss)
print("step",step)
#保存模型
model_saver.save(sess, ckpt_dir+"/model.ckpt", global_step=step)
Tensorflow很多训练好的经典模型(图像,语言)等可以在https://github.com/tensorflow/models中下载(一般在README.MD中有链接),其中Caffe也有很多训练好的模型,可以使用https://github.com/ethereon/caffe-tensorflow来进行Caffe到tensorflow的模型转换
参考:《TensorFlow 技术解析与实战》