1.前言
Tensorflow记录笔记,本文关于tensorflow模型训练的保存问题,加载已训练模型以继续训练,以及加载已训练模型,并且对模型进行改动微训练的问题。
2.模型保存
2.1 tf.train.Saver()
使用saver = tf.train.Saver() 接口进行保存,基本参数如下:
val_list: 启用该list代表保存这个list参数的权重,在图中,非list的参数不会被保存。
max_to_keep: 保存多少个最新的checkpoint文件,默认保存五个文件。
keep_checkpoint_every_n_hours: 多久保存checkpoint文件,单位是小时,
使用saver.save() 函数进行模型的保存,基本参数如下:
sess:当前建立的图会话,new_graph。
save_path: 是保存模型的路劲问题。
用例:
saver=tf.train.Saver()
with tf.Session() as sess:
sess.run(op)
for i in range(epoch):
sess.run(train_op)
saver.save(sess,'path/name',global_step=i)
生成
checkpoint
MyModel.data-00000-of-00001
MyModel.index
MyModel.meta
3.模型读取
3.1构建原始的图
在训练过程中,可以直接把训练模型的结构以及op都拷贝过来*保证图的结构相同,进行原图的构建,然后通过
#path是model的路径
'''
此处已经完成了tensorflow图的构建
'''
saver=tf.train.Saver()
with tf.Session() as sess:
latest=tf.train.latest_checkpoint(checkpoint_dir=pathdir) #pathdir是保存的目录,并不是文件的路径,该行可获取目录下最新的model的路径
saver.restore(sess, latest)
在check_point文件夹内是该内容,保存了model的模型名字,最后的数字对应的是训练到哪一步长。
3.2 断点续连
通过读取文件名字的形式确实模型训练到哪个step,再加载模型在该基础上进行训练。
last_step=2000
saver=tf.train.Saver(max_to_keep=2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
os.makedirs(args.checkpoint_dir,exist_ok=True)
if os.listdir(args.checkpoint_dir) !=[]:
print("model exist,and get start_step")
latest_path=tf.train.latest_checkpoint(args.checkpoint_dir )
start_step=int(latest_path.split("-")[-1]) #获取重新训练的时候的初始步长
if start_step:
print(f"start training from {start_step}.")
else:
print("error happen")
return -1
saver.restore(sess,latest_path)
else:
print("no model")
#saver_uni.restore(sess,tf.train.latest_checkpoint(args.load_dir))
start_step=0
save_path=os.path.join("./model_dir","model.ckpt")
for i in range(start_step+1,last_step+1):
sess.run(train_op)
if i%2000==0: #每2000step保存一次,并且只保存两个模型
saver.save(sess,save_path, global_step=i)
print(sess.run(train_loss))