Tensorflow模型保存、加载和fine-tune

一、保存Tensorflow模型:

1.保存文件说明

Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等。所以,Tensorflow模型有两个主要的文件

1) graph.pbtxt:这其实是一个文本文件,保存了模型的结构信息

2) checkpoint 文件:其实就是一个txt文件,存储的是路径信息

3) .ckpt-*.meta: 其实和上面的graph.pbtxt作用一样都保存了graph结构,只不过meta文件是二进制的

4).ckpt-*.index: 这是一个string-string table,table的key值为tensor名,value为serialized BundleEntryProto。每个BundleEntryProto表述了tensor的metadata,比如那个data文件包含tensor、文件中的偏移量、一些辅助数据等。

5)model.ckpt-*.data-*: 保存了模型的所有变量的值,TensorBundle集合。

6)events.out.tfevents.*...: 保存的就是你的accuracy或者loss在不同时刻的值,是Tensorboard需要的。

2.保存代码说明

为了保存Tensorflow中的图和所有参数的值,我们创建一个tf.train.Saver()类的实例。

如果我们没有在tf.train.Saver()中指定任何参数,它会保存所有变量。如果我们不想保存全部变量而只是想保存一部分的话,我们可以指定想保存的variables/collections.在创建tf.train.Saver实例时,我们将它传递给我们想要保存的变量的列表或字典

#保存全部变量
saver = tf.train.Saver()  

#保存部分变量
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')#获取指定scope的tensor
saver = tf.train.Saver(vgg_ref_vars)#初始化saver时,传入一个var_list的参数

#保存部分变量
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值

保存模型的方法save:

  • session是session对象
  • model_savedpath="path+name" 是你对自己模型的"路径+命名"
  • global_step=num表示迭代多少次就保存模型(比如每迭代1000次后保存模型:global_step=1000)
  • max_to_keep=m ,如果你想保存最近的m个模型
  • keep_checkpoint_every_n_hours=n,每训练n个小时保存一次
  • write_meta_graph=False 不写入网络结构图
saver.save(session, "model_savedpath", global_step=epoch)

注意哦!变量是存在于Session环境中,也就是说,只有在Session环境下才会存有变量 

当使用Supervisor来管理时,如何保存:

sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary saver = sv.saver # 创建saver 

当使用MonitoredTrainingSession来管理时,如何保存: 

使用MonitoredTrainingSession()之前,必须定义global_step变量
global_step = tf.train.get_or_create_global_step()
checkpoint_step = tf.assign_add(global_step, 1)
# 2秒保存一次检查点
save_filename = 'log/checkpoints'
sess = tf.train.MonitoredTrainingSession(checkpoint_dir=save_filename, save_checkpoint_secs=2)

控制checkpoint 数量scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=1))

scaffold = tf.train.Scaffold(saver=tf.train.Saver(max_to_keep=1))
self.sess = tf.train.MonitoredTrainingSession(
           master=self.server.target,
           checkpoint_dir=self.ckpt_dir,
           save_checkpoint_secs=30,
           is_chief=(self.task_index == 0),
           scaffold=scaffold,
           hooks=hooks,
           config=self.conf)


二、模型加载

加载模型及变量说明

1. 全部加载的代码包括两个部分,加载网络结构加载变量参数

(1)tf.train.import_meta_graph(path+"xxx.meta") 加载网络结构

(2)restore(path+"xxx/" )方法加载变量 #path+"xxx/" 指的是保存的模型路径,会自动找到最近保存的变量文件。需要前面训练好的模型参数(即weights、biases等),变量值需要依赖于Session,因此在加载参数时,先要构造好Session

#加载模型结构  
saver = tf.train.import_meta_graph(path+'xxx/yyy.meta')
#加载变量数据  使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型
#path+"xxx/" 指的是保存的模型路径。
saver.restore(sess, tf.train.latest_checkpoint(path+"xxx/"))

 2. 若加载变量只想读取其中一部分变量值

  reader = tf.train.NewCheckpointReader(checkpoint_path)
 (1)通过 var = reader.get_variable_to_shape_map() 获取所有的变量

 (2)通过graph.get_tensor_by_name("变量名")方法,引用保存"变量名"对应的值

def read_checkpoint():
  w = []
  checkpoint_path = 'path'
  reader = tf.train.NewCheckpointReader(checkpoint_path)
  var = reader.get_variable_to_shape_map()
  for key in var:
    if 'weights' in key and 'conv' in key and 'Mo' not in key:
      print('tensorname:', key)
  #   # print(reader.get_tensor(key))

  op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#部分变量恢复
weight=[weights['wc1'],weights['wc2'],weights['wc3a']]
saver = tf.train.Saver(weight)#创建一个saver对象,.values是以列表的形式获取字典值
saver.restore(sess, model_filename)

 对于未被初始化的参数需要手动进行初始化

var = tf.get_variable(name, shape, initializer=tf.contrib.layers.xavier_initializer())

3.进行Fine-tune

通过tf.stop_gradient()方法进行截断反传

# pre-train and fine-tuning
fc2 = graph.get_tensor_by_name("fc2/add:0")
fc2 = tf.stop_gradient(fc2)  # 将模型的一部分进行冻结
fc2_shape = fc2.get_shape().as_list()
# fine -tuning
new_nums = 6
weights = tf.Variable(tf.truncated_normal([fc2_shape[1], new_nums], stddev=0.1), name="w")
biases = tf.Variable(tf.constant(0.1, shape=[new_nums]), name="b")
conv2 = tf.matmul(fc2, weights) + biases
output2 = tf.nn.softmax(conv2)
 

参考网址:

1. Tensorflow加载预训练模型和保存模型_huachao1001的专栏-CSDN博客_tensorflow保存和加载模型

2.关于Tensorflow模型的保存、加载和预导入_YQ8023family的博客-CSDN博客

3. Tensorflow加载预训练模型和保存模型(ckpt文件)以及迁移学习finetuning_loveliuzz的博客-CSDN博客_ckpt文件

4.Tensorflow中保存与恢复模型tf.train.Saver类讲解(恢复部分模型参数的方法)_mieleizhi0522的博客-CSDN博客_saver.restore()恢复部分

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值