一.tensorflow模型保存
模型保存的例子:
import tensorflow as tf
import numpy as np
with tf.name_scope('train'):
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
with tf.Session() as sess:
for t in range(3):
sess.run(tf.global_variables_initializer())
saver.save(sess, 'D:\\tuxiang\\hhh\\my_test_model',t)
其中,t指代的是迭代次数,保存模型时会将迭代次数追加到模型名称后面。
可以更改 tf.train.Saver()中参数max_to_keep的值来设置需要保存的模型数量,也可以设置需要保存的参数,不必将所有的参数都保存,设置方法如下:
# 获取指定scope的tensor
need_save = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')
# 初始化saver时,传入一个var_list的参数,即需要保存的参数
saver = tf.train.Saver(need_save)
保存结果如下:
二.tensorflow模型加载
将模型保存后,我们可以直接调用已保存的模型来对目标数据集进行测试,不必再从头开始训练。
记载上述模型的例子:
1.通过重新创建相同网络(将之前的代码复制过来),并将其作为原始模型。
代码:
import tensorflow as tf
import numpy as np
with tf.name_scope('train'):
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
# w3 = tf.Variable(tf.random_normal(shape=[5]), name='w3')
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 初始w1
print(sess.run('train/w1:0'))
saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
# 赋值后的w1
print(sess.run('train/w1:0'))
# 确认
print(sess.run(w1))
在加载过程中,如果此网络和预加载网络图是一致的则不必初始化全局或局部变量。
通过这种方式恢复图模型,由于saver = tf.train.Saver()在会话外,代表的是当前网络的图,所以该图必须和原始模型的图结构一致才可加载,我试了在“train”中添加了变量w3,在加载参数的过程中会报错。
注:也可以在会话中使用
# 'train'对应的是加载的模型变量所对应的起始name_scope,
# 模型的name_scope和需要加载参数的name_scope保持一致(图模型一致)
tf.train.Saver([var for var in tf.global_variables() if var.name.startswith('train')]) \
.restore(sess,' D:\\tuxiang\\hhh\\my_test_model-1')
来加载一个已构好图的网络对应的模型的全部参数,并可参与训练,我在写LSTM的时候用过上述代码,但在关于w1,w2的初始化时未能实现。
2.使用tf.import_meta_graph(path)将在.meta文件中定义的网络载入到当前图,然后使用特restore恢复参数
import tensorflow as tf
import numpy as np
with tf.Session() as sess:
saver =tf.train.import_meta_graph('D:\\tuxiang\\hhh\\my_test_model-1.meta')
saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
print(sess.run('train/w1:0'))
由于使用了tf.train.import_meta_graph(),不必将之前的网络在此处重写一遍。通过这种方式导入训练好的模型会将模型的所有参数导入。
注:如果已经重新构建了网络,又把之前的图加载进来,即使用tf.import_meta_graph(),然后再去restore网络的参数,则预训练模型的参数是不会被加载进来的。
3.只加载指定变量
为了保存或加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或加载的变量。
比如在加载模型时候使用saver = tf.train.Saver([w1]),则只有变量w1会被加载进来:
import tensorflow as tf
import numpy as np
with tf.variable_scope('train'):
w1 = tf.get_variable('w1', shape = [2])
w2 = tf.get_variable( name='w2',shape=[2])
w3 = tf.get_variable( name='w3',shape=[2])
saver = tf.train.Saver([w1])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run('train/w1:0'))
print(sess.run('train/w2:0'))
saver.restore(sess,'D:\\tuxiang\\hhh\\my_test_model-1')
print(sess.run('train/w1:0'))
print(sess.run('train/w2:0'))
# 输出结果
[0.92955863 0.4762975 ]
[0.88360226 0.48021615]
INFO:tensorflow:Restoring parameters from D:\tuxiang\hhh\my_test_model-1
[-0.22390778 1.227742 ]
[0.88360226 0.48021615]
除了可以选取需要加载或保存的变量,tf.train.Saver还可以支持在保存或加载时给变量重新命名。
例如:声明的变量名称和模型中保存的不一样。
import tensorflow as tf
import numpy as np
# 声明的变量和模型中已保存变量的名称不同
w1 = tf.get_variable('w1_1', shape = [2])
w2 = tf.get_variable( name='w2_1',shape=[5])
w3 = tf.get_variable( name='w3_1',shape=[2])
# 如果直接使用 tf.train.Saver()来加载则会报变量找不到的错误
# 此时使用一个字典来直接重命名变量就可加载原来的模型了,这个
# 字典指定原来名称为"w1"的变量现在加载到w1中(名称为w1_1)
saver = tf.train.Saver({'w1':w1,"w2":w2})
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run('w1_1:0'))
print(sess.run('w2_1:0'))
saver.restore(sess,'D:\\tuxiang\\hhh\\new\\my_test_model-1')
print(sess.run('w1_1:0'))
print(sess.run('w2_1:0'))
*************output********************
[0.03663075 0.6752244 ]
[-0.7689313 0.74632037 0.6029091 -0.01121217 0.64495254]
INFO:tensorflow:Restoring parameters from D:\tuxiang\hhh\new\my_test_model-1
[ 1.3500427 -0.08136963]
[-0.17664449 0.6879551 1.0955236 -1.7334721 1.4500109 ]
参考资料:
1.tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)
2.TensorFlow中tf.train.Saver类说明.