tensorflow模型的存储与加载

在使用tensoflow时我们通常是需要将模型的训练与模型的测试分开,所有在模型测试时,我们需要调用已经训练好的模型进行测试,因此,本文将对tensorflow下模型的存储以及加载进行记录:
1、tensoflow保存下来的模型是什么样?
在训练好神经网络模型后,我们通常会将他保存下来,方便在模型的测试阶段进行调用。那么tensorflow保存下来的模型由什么组成呢?
在这里插入图片描述在模型存储后,会出现四个文件:
checkpoint 文本文件,记录了模型文件的路径信息列表
model.ckpt.data-00000-of-00001 网络权重信息
model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)

2、tensorflow是通过tf.train.saver类来实现模型的存储与调用

模型的存储

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)

以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数,其具体的实现如下所示:


W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt") 

模型的加载
在tensorflow下模型的加载是存在两中方式:
首先定义一个模型

tf.reset_default_graph()  
 
###——————————————————定义神经网络——————————————————
with tf.name_scope('X_Y_input'):
    X=tf.placeholder(tf.float32, shape=[None,time_step_train,input_size],name="x_input")
    Y_=tf.placeholder(tf.float32, shape=[None,output_size],name="y_input")
 
with tf.name_scope('keep_prob'):
    keep_prob = tf.placeholder(tf.float32,name="keep_prob")
 
with tf.name_scope('lstm'):
#输入层、输出层权重、偏置
    w_in=tf.Variable(tf.random_normal([input_size,rnn_unit]),name="w_in")
    b_in=tf.Variable(tf.constant(0.1,shape=[rnn_unit,]),name="b_in")
    
    w_out=tf.Variable(tf.random_normal([rnn_unit,1]),name="w_out")
    b_out=tf.Variable(tf.constant(0.1,shape=[1,]),name="b_out")
    
    with tf.name_scope('lstm_input'):        
        input_x=tf.reshape(X,[-1,input_size])  
         
    with tf.name_scope('lstm_rnn'):       
       lstm_cell=tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)
       lstm_cell=tf.nn.rnn_cell.DropoutWrapper(lstm_cell,input_keep_prob=1.0, output_keep_prob=keep_prob)
       cell=tf.nn.rnn_cell.MultiRNNCell([lstm_cell]*num_layer)   
       init_state=cell.zero_state(batch_size,dtype=tf.float32)    
       output_rnn,final_states=tf.nn.dynamic_rnn(cell, input_rnn_r,initial_state=init_state, dtype=tf.float32)
       output_rnn_last=output_rnn[:,-1,:]
    with tf.name_scope('lstm_out'): 
       pred_out=tf.matmul(output_rnn_last,w_out)+b_out
 
##——————————————————定义误差 学习率 和优化器——————————————————
global_step = tf.Variable(0,name="global_step")  
with tf.name_scope('learning_rate'): 
    learning_rate = tf.train.exponential_decay(lr,global_step,len(batch_index_train),decay_rate, staircase=True)
 
with tf.name_scope('loss_mse'):     
    loss_mse=tf.reduce_mean(tf.square(pred_out-Y_))
tf.summary.scalar('loss_mse',loss_mse)
 
with tf.name_scope('train_op'): 
    train_op=tf.train.AdamOptimizer(learning_rate).minimize(loss_mse,global_step=global_step)
#    train_op=tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_mse,global_step=global_step )
 
#——————————————————定义全局变量保存器——————————————————
saver=tf.train.Saver(max_to_keep=20)

1)只加载数据,不加载图结构,只不过在Saver对象实例化之前是需要提前定义好新的图的结构,否则报错`


sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

2)连同图结构一同加载,重新导入模型,不需要再重新定义一遍上述的图

 sess= tf.Session()
   
    saver=tf.train.import_meta_graph(modelfile)   # end with .meta 文件
    graph=tf.get_default_graph()                  
    tesor_name_list=[tensor.name for tensor in graph.as_graph_def().node] # 变量名
 
    X=graph.get_tensor_by_name('X_Y_input/x_input:0')   # 我们需要的输入
 
    Y=graph.get_tensor_by_name('X_Y_input/y_input:0')
 
    keep_prob=graph.get_tensor_by_name('keep_prob/keep_prob:0')   # 需要的参数
 
    pred_out=graph.get_tensor_by_name('lstm/lstm_out/add:0')      # 我们需要的预测输出
    
    module_file = tf.train.latest_checkpoint(model_parfile)   # .meta 所在的文件夹名称
    saver.restore(sess, module_file)
 
    for step in range(batch_num):
        prob=sess.run(pred_out,feed_dict={X:test[step*4096:(step+1)*4096],keep_prob: 1.0})   # 上述重新加载的参数在此处都用到了
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值