很简单,在原代码基础上加几句代码即可
一、保存模型
3句即可,首先建立一个saver和一个路径
下面两句添加在session创建之前,参数、网络结构定义之后
# (C)pengchengIT 2021
#保存模型
saver = tf.train.Saver()
model_path = "model/001model"
#启动session
with tf.session() as sess:
...
然后调用save,自动将session中的参数保存起来。
下面一句添加在session最后面
save_path = saver.save(sess, model_path)
print("Model saved in file: %s" %save_path)
运行代码,会在代码文件的同级目录下找到model文件夹,其中有4个文件夹
二、读取模型
也是3句,在创建上述模型后,就可以直接用上次训练好的模型,直接对测试集操作了
注意:读取模型前,源代码中参数、网络结构的定义不变,保留不动,可以把训练时的session删去,创建新的session,读取模型后,所有造作在新的session里,其实只有session变了
同样要建立一个saver和一个路径(之前保存的地方),恢复模型这句添加在session创建好后
saver = tf.train.Saver()
model_path = "model/001model"
#启动session
with tf.session() as sess:
sess.run(tf.global_variables_initializer())
#回复模型变量
saver.restore(sess, model_path)