tensorflow模型保存与加载

一.tensorflow模型保存

模型保存的例子:

import tensorflow as tf  
import numpy as np
with tf.name_scope('train'):
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')  
    w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')  
saver = tf.train.Saver()  
with tf.Session() as sess:
    for t in range(3):
        sess.run(tf.global_variables_initializer())  
        saver.save(sess, 'D:\\tuxiang\\hhh\\my_test_model',t)  

其中,t指代的是迭代次数,保存模型时会将迭代次数追加到模型名称后面。
可以更改 tf.train.Saver()中参数max_to_keep的值来设置需要保存的模型数量,也可以设置需要保存的参数,不必将所有的参数都保存,设置方法如下:

# 获取指定scope的tensor
need_save = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')
# 初始化saver时,传入一个var_list的参数,即需要保存的参数
saver = tf.train.Saver(need_save)

保存结果如下:
![在这里插入图片描述](https://img-blog.csdnimg.cn/20190827215926468.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L01hcmtfMjAxOA==,size_16,color_FFFFFF,t_70

二.tensorflow模型加载

将模型保存后,我们可以直接调用已保存的模型来对目标数据集进行测试,不必再从头开始训练。
记载上述模型的例子:

1.通过重新创建相同网络(将之前的代码复制过来),并将其作为原始模型。

代码:

import tensorflow as tf  
import numpy as np
with tf.name_scope('train'):
    w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')  
    w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2') 
#     w3 = tf.Variable(tf.random_normal(shape=[5]), name='w3') 
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 初始w1
    print(sess.run('train/w1:0'))
    saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
    # 赋值后的w1
    print(sess.run('train/w1:0'))
    # 确认
    print(sess.run(w1))

在这里插入图片描述
在加载过程中,如果此网络和预加载网络图是一致的则不必初始化全局或局部变量。
通过这种方式恢复图模型,由于saver = tf.train.Saver()在会话外,代表的是当前网络的图,所以该图必须和原始模型的图结构一致才可加载,我试了在“train”中添加了变量w3,在加载参数的过程中会报错。

注:也可以在会话中使用

# 'train'对应的是加载的模型变量所对应的起始name_scope,
# 模型的name_scope和需要加载参数的name_scope保持一致(图模型一致)
tf.train.Saver([var for var in tf.global_variables() if var.name.startswith('train')]) \
            .restore(sess,' D:\\tuxiang\\hhh\\my_test_model-1')

来加载一个已构好图的网络对应的模型的全部参数,并可参与训练,我在写LSTM的时候用过上述代码,但在关于w1,w2的初始化时未能实现。
在这里插入图片描述

2.使用tf.import_meta_graph(path)将在.meta文件中定义的网络载入到当前图,然后使用特restore恢复参数
import tensorflow as tf  
import numpy as np
with tf.Session() as sess:
    saver =tf.train.import_meta_graph('D:\\tuxiang\\hhh\\my_test_model-1.meta')
    saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
    print(sess.run('train/w1:0'))  

由于使用了tf.train.import_meta_graph(),不必将之前的网络在此处重写一遍。通过这种方式导入训练好的模型会将模型的所有参数导入。

注:如果已经重新构建了网络,又把之前的图加载进来,即使用tf.import_meta_graph(),然后再去restore网络的参数,则预训练模型的参数是不会被加载进来的。
3.只加载指定变量

为了保存或加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或加载的变量。
比如在加载模型时候使用saver = tf.train.Saver([w1]),则只有变量w1会被加载进来:

import tensorflow as tf  
import numpy as np

with tf.variable_scope('train'):
    w1 = tf.get_variable('w1', shape = [2])  
    w2 = tf.get_variable( name='w2',shape=[2])  
    w3 = tf.get_variable( name='w3',shape=[2])
    
saver = tf.train.Saver([w1])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) 
    print(sess.run('train/w1:0'))
    print(sess.run('train/w2:0'))
    saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
    print(sess.run('train/w1:0'))
    print(sess.run('train/w2:0'))
# 输出结果
[0.92955863 0.4762975 ]
[0.88360226 0.48021615]
INFO:tensorflow:Restoring parameters from D:\tuxiang\hhh\my_test_model-1
[-0.22390778  1.227742  ]
[0.88360226 0.48021615]

除了可以选取需要加载或保存的变量,tf.train.Saver还可以支持在保存或加载时给变量重新命名。
例如:声明的变量名称和模型中保存的不一样。

import tensorflow as tf  
import numpy as np

# 声明的变量和模型中已保存变量的名称不同

w1 = tf.get_variable('w1_1', shape = [2])  
w2 = tf.get_variable( name='w2_1',shape=[5])  
w3 = tf.get_variable( name='w3_1',shape=[2])

# 如果直接使用 tf.train.Saver()来加载则会报变量找不到的错误   
# 此时使用一个字典来直接重命名变量就可加载原来的模型了,这个
# 字典指定原来名称为"w1"的变量现在加载到w1中(名称为w1_1)
    
saver = tf.train.Saver({'w1':w1,"w2":w2})
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) 
    print(sess.run('w1_1:0'))
    print(sess.run('w2_1:0'))
    saver.restore(sess,'D:\\tuxiang\\hhh\\new\\my_test_model-1')
    print(sess.run('w1_1:0'))
    print(sess.run('w2_1:0'))
*************output********************
[0.03663075 0.6752244 ]
[-0.7689313   0.74632037  0.6029091  -0.01121217  0.64495254]
INFO:tensorflow:Restoring parameters from D:\tuxiang\hhh\new\my_test_model-1
[ 1.3500427  -0.08136963]
[-0.17664449  0.6879551   1.0955236  -1.7334721   1.4500109 ]

参考资料:

1.tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)
2.TensorFlow中tf.train.Saver类说明.

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

还是少年呀

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值