Tensorflow Save

保存为四个文件:

my-model.ckpt.meta          保存整个计算图的结构

my-model.ckpt.data-*        保存模型中每个变量的取值

my-model.ckpt.index

checkpoint                          记录目录下所有模型文件列表


.ckpt模型     图结构.meta与变量值.ckpt分离

from __future__ import print_function
import tensorflow as tf
import numpy as np


'''*********************自定义图运算******************'''
'''*********************自定义图运算******************'''
'''*********************自定义图运算******************'''


'''
#**********************************************在一张图、会话中存入 再载入 变量************************************
tf.reset_default_graph()
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
    print(W.name,b.name)
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("Save to path: ", save_path)
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))
    print(W.name,b.name)

tf.reset_default_graph()  #不然以下的w名称为 weight_1     
'''
'''
#***********************************************在不同图、会话中存入 载入变量**************************************
#----------------------------------save------------------
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')             #导出张量名为weights:0  计算节点名weights    
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
print(W.name)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("Save to path: ", save_path)



#---------------------------------reload-----------------

tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
#1 WW = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights2")
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
#print(W.name)
# 自己定义图上运算 参数无需初始化 而将值直接按名称加载进来

saver = tf.train.Saver()
#2 saver = tf.train.Saver([W])              # **************************[对应名称的变量的张量] 列表形式只获取部分变量
#1 saver = tf.train.Saver({'weights':WW})   #***********************{‘原名’: }形式 重命名变量  将原名weight的值放入WW中     %名字无需加上 :0 部分  --计算节点
with tf.Session() as sess:
   
    # 提取变量
    saver.restore(sess, "my_net/save_net.ckpt")
    #print("weights:", sess.run(W))
    print("biases:", sess.run(W))
    
'''  

'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''

#**********************************************直接加载图而无需重复定义图上的运算***********************************

tf.reset_default_graph() 

init = tf.global_variables_initializer()
saver = tf.train.import_meta_graph("C:/Users/Administrator/Desktop/my_net/save_net.ckpt.meta")  #加载图 同时 下面导入变量值

with tf.Session() as sess:
    sess.run(init)
    # 提取变量
    saver.restore(sess, "my_net/save_net.ckpt")
    #通过张量的名称来获取张量 )#Tensor names must be of the form "<op_name>:<output_index>".
    print(sess.run(tf.get_default_graph().get_tensor_by_name('weights:0')))                              #%名字需加上 :0 部分   因为是获取张量


.pb模型   freeze的模型,该模型已经是包含图和相应的参数了

import tensorflow as tf
from tensorflow.python.framework import graph_util  


'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''
'''*********************非自定义图运算******************'''

    
''' 

#***********************************************在不同图、会话中存入 载入变量**************************************
#----------------------------------save------------------
tf.reset_default_graph() #!!!!!!!!!!!!!!!!!
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')             #导出张量名为weights:0  计算节点名weights    
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
print(W.name)

init = tf.global_variables_initializer()


with tf.Session() as sess:
    sess.run(init)
    graph_def = tf.get_default_graph().as_graph_def()    #得到当前的图的 GraphDef 部分,==输入层到输出层的计算过程
    output_graph_def = graph_util.convert_variables_to_constants(sess,   #计算图中的变量及其取值通过常量的方式保存于一个文件中  
                                                        graph_def, ['weights'])   ##需要保存【计算节点】的名字    %舍去无用节点 保存该节点下子图及变量值
  
    with tf.gfile.GFile("model/w.pb", 'wb') as f:  #通过 tf.gfile.GFile 进行模型持久化
        f.write(output_graph_def.SerializeToString())   # 序列化输出

'''

#---------------------------------reload-----------------
from tensorflow.python.platform import gfile  

tf.reset_default_graph() #!!!!!!!!!!!!!!!!!

  
with tf.Session() as sess:  
    model_filename = "Model/combined_model.pb"  
    with gfile.FastGFile(model_filename, 'rb') as f:  
        graph_def = tf.GraphDef()  
        graph_def.ParseFromString(f.read())  
  
    result = tf.import_graph_def(graph_def, return_elements=['weights:0'])   #得输出节点的值--【张量】
    print(sess.run(result)) # [array([ 3.], dtype=float32)]  


参考:

http://blog.csdn.net/marsjhao/article/details/72829635  书译

http://blog.csdn.net/michael_yt/article/details/74737489

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值