Tensotflow1.0入门(七)-tensorflow模型保存和加载模型

tensorflow的API提供以下两种方式来存储和加载模型
1.生成检查点文件(checkpoint file),扩展名一般为.ckpr,通过tf.train.Saver来保存,包含权重和其他在程序中定义的变量,但不包含图结构。
2.生成图结构,扩展名一般为.pb,使用tf.train.write_graph()保存,只包含图结构,不包含权重
所以一般都是两个一起结合使用

例子:

# coding=utf-8
'''
Created on 2019年4月5日

@author: admin
'''

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

mnist = input_data.read_data_sets('mnist', one_hot=True)

trX,trY,teX,teY = mnist.train.images,mnist.train.labels,mnist.train.images,mnist.test.labels

x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])

w1 = tf.Variable(tf.random_normal([784,625]))
    
w2 = tf.Variable(tf.random_normal([625,625]))

o = tf.Variable(tf.random_normal([625,10]))

keep_prob1 = tf.placeholder("float")
keep_prob2 = tf.placeholder("float")

def model(x,w1,w2,o):
    
    x = tf.nn.relu(tf.matmul(x,w1))
    x = tf.nn.dropout(x,keep_prob1)
    
    
    x = tf.nn.relu(tf.matmul(x,w2))
    x = tf.nn.dropout(x,keep_prob2)

    return tf.matmul(x,o)
pred = model(x,w1,w2,o)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y),name="cost")
train_op = tf.train.RMSPropOptimizer(0.001,0.9,name="train_op").minimize(cost)
#虽然前面命名为cost,但是最好还是打印出来看一下名称
print(cost)
# Tensor("cost:0", shape=(), dtype=float32)
print(train_op)
predict_op = tf.argmax(pred,1)

# 保存模型超参数
ckpt_dir = "modle_ckpt"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)
model_saver = tf.train.Saver()
# 保存模型图
graph_dir = "modle_graph"
if not os.path.exists(graph_dir):
    os.makedirs(graph_dir)

with tf.Session() as sess:
    tf.initialize_all_variables().run()
    #保存图模型
    tf.train.write_graph(sess.graph_def, graph_dir, 'graph.pbtxt',as_text=False)
    #加载图模型
    with tf.gfile.FastGFile("modle_graph/graph.pbtxt","rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        eles = tf.import_graph_def(graph_def)
#         print(eles)
        cost1 = sess.graph.get_tensor_by_name("cost:0")
        print(cost1)
        train_op1 = sess.graph.get_operation_by_name("train_op")
        print(train_op1)
    #加载已经保存好的模型超参数
#     ckpt = tf.train.get_checkpoint_state(ckpt_dir)
#     print(ckpt.model_checkpoint_path)
#     model_saver.restore(sess,ckpt.model_checkpoint_path)
        for step in range(1):
            for start,end in zip(range(0,len(trX),128),range(128,len(trX)+1,128)):
                loss,_ = sess.run([cost1,train_op1],feed_dict={x:trX[start:end],y:trY[start:end],keep_prob1:0.8,keep_prob2:0.8})
                print(loss)
            print("step",step)
            #保存模型
            model_saver.save(sess, ckpt_dir+"/model.ckpt", global_step=step)

Tensorflow很多训练好的经典模型(图像,语言)等可以在https://github.com/tensorflow/models中下载(一般在README.MD中有链接),其中Caffe也有很多训练好的模型,可以使用https://github.com/ethereon/caffe-tensorflow来进行Caffe到tensorflow的模型转换

参考:《TensorFlow 技术解析与实战》

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值