一、保存Tensorflow模型:
1.保存文件说明
Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等。所以,Tensorflow模型有两个主要的文件:
1) graph.pbtxt:这其实是一个文本文件,保存了模型的结构信息
2) checkpoint 文件:其实就是一个txt文件,存储的是路径信息
3) .ckpt-*.meta: 其实和上面的graph.pbtxt作用一样都保存了graph结构,只不过meta文件是二进制的
4).ckpt-*.index: 这是一个string-string table,table的key值为tensor名,value为serialized BundleEntryProto。每个BundleEntryProto表述了tensor的metadata,比如那个data文件包含tensor、文件中的偏移量、一些辅助数据等。
5)model.ckpt-*.data-*: 保存了模型的所有变量的值,TensorBundle集合。
6)events.out.tfevents.*...: 保存的就是你的accuracy或者loss在不同时刻的值,是Tensorboard需要的。
2.保存代码说明
为了保存Tensorflow中的图和所有参数的值,我们创建一个tf.train.Saver()类的实例。
如果我们没有在tf.train.Saver()中指定任何参数,它会保存所有变量。如果我们不想保存全部变量而只是想保存一部分的话,我们可以指定想保存的variables/collections.在创建tf.train.Saver实例时,我们将它传递给我们想要保存的变量的列表或字典。
#保存全部变量
saver = tf.train.Saver()
#保存部分变量
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')#获取指定scope的tensor
saver = tf.train.Saver(vgg_ref_vars)#初始化saver时,传入一个var_list的参数
#保存部分变量
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
保存模型的方法save:
- session是session对象
- model_savedpath="path+name" 是你对自己模型的"路径+命名"
- global_step=num表示迭代多少次就保存模型(比如每迭代1000次后保存模型:global_step=1000)
- max_to_keep=m ,如果你想保存最近的m个模型
- keep_checkpoint_every_n_hours=n,每训练n个小时保存一次
- write_meta_graph=False 不写入网络结构图
saver.save(session, "model_savedpath", global_step=epoch)
注意哦!变量是存在于Session环境中,也就是说,只有在Session环境下才会存有变量值
当使用Supervisor来管理时,如何保存:
sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary saver = sv.saver # 创建saver
当使用MonitoredTrainingSession来管理时,如何保存:
使用MonitoredTrainingSession()之前,必须定义global_step变量
global_step = tf.train.get_or_create_global_step()
checkpoint_step = tf.assign_add(global_step, 1)
# 2秒保存一次检查点
save_filename = 'log/checkpoints'
sess = tf.train.MonitoredTrainingSession(checkpoint_dir=save_filename, save_checkpoint_secs=2)
控制checkpoint 数量scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=1))
scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=1))
self.sess = tf.train.MonitoredTrainingSession(
master=self.server.target,
checkpoint_dir=self.ckpt_dir,
save_checkpoint_secs=30,
is_chief=(self.task_index == 0),
scaffold=scaffold,
hooks=hooks,
config=self.conf)
二、模型加载
加载模型及变量说明
1. 全部加载的代码包括两个部分,加载网络结构和加载变量参数
(1)tf.train.import_meta_graph(path+"xxx.meta") 加载网络结构
(2)restore(path+"xxx/" )方法加载变量 #path+"xxx/" 指的是保存的模型路径,会自动找到最近保存的变量文件。需要前面训练好的模型参数(即weights、biases等),变量值需要依赖于Session,因此在加载参数时,先要构造好Session:
#加载模型结构
saver = tf.train.import_meta_graph(path+'xxx/yyy.meta')
#加载变量数据 使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型
#path+"xxx/" 指的是保存的模型路径。
saver.restore(sess, tf.train.latest_checkpoint(path+"xxx/"))
2. 若加载变量只想读取其中一部分变量值
reader = tf.train.NewCheckpointReader(checkpoint_path)
(1)通过 var = reader.get_variable_to_shape_map() 获取所有的变量
(2)通过graph.get_tensor_by_name("变量名")方法,引用保存"变量名"对应的值
def read_checkpoint():
w = []
checkpoint_path = 'path'
reader = tf.train.NewCheckpointReader(checkpoint_path)
var = reader.get_variable_to_shape_map()
for key in var:
if 'weights' in key and 'conv' in key and 'Mo' not in key:
print('tensorname:', key)
# # print(reader.get_tensor(key))
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#部分变量恢复
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.restore(sess, model_filename)
对于未被初始化的参数需要手动进行初始化
var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())
3.进行Fine-tune
通过tf.stop_gradient()方法进行截断反传
# pre-train and fine-tuning
fc2 = graph.get_tensor_by_name("fc2/add:0")
fc2 = tf.stop_gradient(fc2) # 将模型的一部分进行冻结
fc2_shape = fc2.get_shape().as_list()
# fine -tuning
new_nums = 6
weights = tf.Variable(tf.truncated_normal([fc2_shape[1], new_nums], stddev=0.1), name="w")
biases = tf.Variable(tf.constant(0.1, shape=[new_nums]), name="b")
conv2 = tf.matmul(fc2, weights) + biases
output2 = tf.nn.softmax(conv2)
参考网址:
1. Tensorflow加载预训练模型和保存模型_huachao1001的专栏-CSDN博客_tensorflow保存和加载模型
2.关于Tensorflow模型的保存、加载和预导入_YQ8023family的博客-CSDN博客
3. Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning_loveliuzz的博客-CSDN博客_ckpt文件
4.Tensorflow中保存与恢复模型tf.train.Saver类讲解(恢复部分模型参数的方法)_mieleizhi0522的博客-CSDN博客_saver.restore()恢复部分