在深度学习中,迁移学习是一个很普遍的操作,即将一个训练好的网络的一部分迁移到另一个网络,作为另一个网络结果的一部分.但是,我们要怎么保存和迁移呢?今天将以tensorflow的代码为例,给大家一个简单的介绍.
采用的函数是: tf.train.Saver()
1.存储和读取的步骤
(1)存储saver.save(sess, save_dir)
saver = tf.train.Saver()#声明ta.train.Saver()类用于保存
save_path = saver.save(sess,'save/filename.ckpt')#保存路径为相对路径的save文件夹,保存名为filename.ckpt
存储之后总共有几个后缀的文件:
filename.ckpt.meta:保存tensorflow的网络(计算图)结构
filename.ckpt:保存tensorflow中每一个变量的值
ckptpoint:保存一个目录下所有的模型文件列表
(2)读取saver.restore()
save.restore(sess, 'save/filename.ckpt')#从保存路径读取
在读取之前,先定义号和原来模型中相同的变量.读取出的结果直接赋值给变量使用
(3)直接测试已经训练好的模型
可以通过meta graph构建网络、载入训练时得到的参数,并使用默认的session:
saver = tf.train.import_meta_graph(‘save/filename.meta’)
saver.restore(tf.get_default_session(),’ save/filename.ckpt-16000’)
2.代码实现
代码实现我懒得写了,引用一个作者(Traphix)写好的,比较清晰明了: https://www.jianshu.com/p/83fa3aa2d0e9
(1)训练网络的
import