Tensorflow-模型的保存、恢复以及fine-tune

最近在做的本科毕设需要对迁移过来的ConvNet进行在线fine-tune,由于之前在学习Tensorflow时对模型的保存恢复学习的不够深入,所以今天花了一个下午看了几篇文章,觉得有的写的很不错,就搬运过来。在最后有自己的总结。

参考原文:https://zhuanlan.zhihu.com/p/53814653

使用tensorflow的过程中,我们常常会用到训练好的模型。我们可以直接使用训练好的模型进行测试或者对训练好的模型做进一步的微调。(微调是指初始化网络参数的时候不再是随机初始化,而是使用先前训练好的权重参数进行初始化,在此基础上对网络的全部或者局部参数进行重新训练的过程)。为了实现模型的复用或微调,我将从以下四个方面进行说明:

  • 模型是指什么?
  • 如何保存模型?
  • 如何恢复模型?
  • 如何进行微调?

一、模型是指什么?

tensorflow训练后需要保存的模型主要包含两部分,一是网络图,二是网络图里的参数值。保存的模型文件结构如下(假设每过1000次保存一次):

checkpoint
MyModel-1000.meta
MyModel-1000.data-00000-of-00001
MyModel-1000.index
MyModel-2000.meta
MyModel-2000.data-00000-of-00001
MyModel-2000.index
MyModel-3000.meta
MyModel-3000.data-00000-of-00001
MyModel-3000.index
.......

1 checkpoint

checkpoint是一个文本文件,如下所示。其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。

model_checkpoint_path: "MyModel-3000"
all_model_checkpoint_paths: "MyModel-1000"
all_model_checkpoint_paths: "MyModel-2000"
all_model_checkpoint_paths: "MyModel-3000"
......

2 .meta文件

.meta 文件用于保存网络结构,且以 protocol buffer 格式进行保存。protocol buffer是Google 公司内部使用的一种轻便高效的数据描述语言。类似于XML能够将结构化数据序列化,protocol buffer也可用于序列化结构化数据,并将其用于数据存储、通信协议等方面。相较于XML,protocol buffer更小、更快、也更简单。划重点:网络结构,仅仅是网络结构

3 .data-00000-of-00001 文件和 .index 文件

在tensorflow 0.11之前,保存的文件结构如下。tensorflow 0.11之后,将ckpt文件拆分为了.data-00000-of-00001 和 .index 两个文件。.ckpt是二进制文件,保存了所有变量的值及变量的名称。拆分后的.data-00000-of-00001 保存的是变量值,.index文件保存的是.data文件中数据和 .meta文件中结构图之间的对应关系(也就是变量的名称)划重点:变量!tf.Variable()

checkpoint
MyModel.meta
MyModel.ckpt

二、如何保存模型?

tensorflow 提供tf.train.Saver类及tf.train.Saver类下面的save方法共同保存模型。下面分别说明tf.train.Saver类及save方法:

tf.train.Saver(var_list=None, reshape=False, sharded=False, max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0, name=None, restore_sequentially=False, 
    saver_def=None, builder=None, defer_build=False, allow_empty=False,
    write_version=saver_pb2.SaverDef.V2, pad_step_number=False)
就常用的参数进行说明:
var_list:如果我们不对tf.train.Saver指定任何参数,默认会保存所有变量。如果你只想保存一部分变量,
          可以通过将需要保存的变量构造list或者dictionary,赋值给var_list。
max_to_keep:tensorflow默认只会保存最近的5个模型文件,如果你希望保存更多,可以通过max_to_keep来指定
keep_checkpoint_every_n_hours:设置每隔几小时保存一次模型


save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix="meta",
      write_meta_graph=True, write_state=True)
就常用的参数进行说明:
sess:在tensorflow中,只有开启session时数据才会流动,因此保存模型的时候必须传入session。
save_path: 模型保存的路径及模型名称。
global_step:定义每隔多少步保存一次模型,每次会在保存的模型名称后面加上global_step的值作为后缀
write_meta_graph:布尔值,True表示每次都保存图,False表示不保存图(由于图是不变的,没必要每次都去保存)

注意:保存变量的时候必须在session中;保存的变量必须已经初始化;

1.简单示例

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
w3 = tf.Variable(tf.random_normal(shape=[1]), name='w3')
saver = tf.train.Saver()#未指定任何参数,默认保存所有变量。等价于saver = tf.train.Saver(tf.trainable_variables())
save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel

with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	saver.save(sess, save_path)

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.data-00000-of-00001
MyModel.index
MyModel.meta

2.经典示例

import tensorflow as tf
from six.moves import xrange
import os

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w11')#变量w1在内存中的名字是w11;恢复变量时应该与name的名字保持一致
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w22')
w3 = tf.Variable(tf.random_normal(shape=[5]), name='w33')

#保存一部分变量[w1,w2];只保存最近的5个模型文件;每2小时保存一次模型
saver = tf.train.Saver([w1, w2],max_to_keep=5, keep_checkpoint_every_n_hours=2)
save_path = './checkpoint_dir/MyModel'#定义模型保存的路径./checkpoint_dir/及模型名称MyModel

# Launch the graph and train, saving the model every 1,000 steps.
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	for step in xrange(100):
		if step % 10 == 0:
			# 每隔step=10步保存一次模型( keep_checkpoint_every_n_hours与global_step可同时使用,表示'与',通常任选一个就够了);
			#每次会在保存的模型名称后面加上global_step的值作为后缀
			# write_meta_graph=False表示不保存图
			saver.save(sess, save_path, global_step=step, write_meta_graph=False)
			# 如果模型文件中没有保存网络图,则使用如下语句保存一张网络图(由于网络图不变,只保存一次就行)
			if not os.path.exists('./checkpoint_dir/MyModel.meta'):
				# saver.export_meta_graph(filename=None, collection_list=None,as_text=False,export_scope=None,clear_devices=False)
				# saver.export_meta_graph()仅仅保存网络图;参数filename表示网络图保存的路径即网络图名称
				saver.export_meta_graph('./checkpoint_dir/MyModel.meta')#定义网络图保存的路径./checkpoint_dir/及网络图名称MyModel.meta
                                #注意:tf.train.export_meta_graph()等价于tf.train.Saver.export_meta_graph()

执行后,在checkpoint_dir目录下创建模型文件如下:

checkpoint
MyModel.meta
MyModel-50.data-00000-of-00001
MyModel-50.index
MyModel-60.data-00000-of-00001
MyModel-60.index
MyModel-70.data-00000-of-00001
MyModel-70.index
MyModel-80.data-00000-of-00001
MyModel-80.index
MyModel-90.data-00000-of-00001
MyModel-90.index

三、如何恢复模型?

tensorflow保存模型时将网络图和网络图里的参数值分开保存。因此,在恢复模型时,也要分为2步:构造网络图和加载参数。

模型的恢复分为两步,第一步是graph的重新构建,第二步是模型参数的加载。模型参数的加载对应的是变量的初始化。

1 构造网络图

构造网络图可以手动创建(需要创建一个跟保存的模型一模一样的网络图)

也可以从meta文件里加载graph进行创建,如下:

#首先恢复graph
saver = tf.train.import_meta_graph('./checkpoint_dir/MyModel.meta')

2 恢复参数有两种方式,如下:

with tf.Session() as sess:
    #恢复最新保存的权重
    saver.restore(sess, tf.train.latest_checkpoint('./checkpoint_dir'))
    #指定一个权重恢复
    saver.restore(sess, './checkpoint_dir/MyModel-50')#注意不要加文件后缀名。若权重保存为.ckpt则需要加上后缀

 

 

四、如何进行微调?(*)

上面叙述了如何恢复模型,那么,对于恢复出来的模型应该如何使用呢?这里以tensorflow官网给出的vgg为例进行说明。下载地址

恢复出来的模型有四种用途:

  • 查看模型参数
  • 直接使用原始模型进行测试
  • 扩展原始模型(直接使用扩展后的网络进行测试,扩展后需要重新训练的情况见微调部分)
  • 微调:使用训练好的权重参数进行初始化,在此基础上对网络的全部或局部参数进行重新训练

1.查看模型参数

import tensorflow as tf
import vgg

# build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)

saver = tf.train.Saver()
with tf.Session() as sess:
	saver.restore(sess, './vgg_16.ckpt')  # 权重保存为.ckpt则需要加上后缀
	"""
	   查看恢复的模型参数
	   tf.trainable_variables()查看的是所有可训练的变量;
	   tf.global_variables()获得的与tf.trainable_variables()类似,只是多了一些非trainable的变量,比如定义时指定为trainable=False的变量;
	   sess.graph.get_operations()则可以获得几乎所有的operations相关的tensor
	   """
	tvs = [v for v in tf.trainable_variables()]
	print('获得所有可训练变量的权重:')
	for v in tvs:
		print(v.name)
		print(sess.run(v))
	
	gv = [v for v in tf.global_variables()]
	print('获得所有变量:')
	for v in gv:
		print(v.name, '\n')
	
	# sess.graph.get_operations()可以换为tf.get_default_graph().get_operations()
	ops = [o for o in sess.graph.get_operations()]
	print('获得所有operations相关的tensor:')
	for o in ops:
		print(o.name, '\n')

2.直接使用原始模型进行测试

import tensorflow as tf
import vgg
import numpy as np
import cv2

image = cv2.imread('./cat.18.jpg')
print(image.shape)
res = cv2.resize(image, (224,224))
res_image = np.expand_dims(res, 0)
print(res_image.shape, type(res_image))

#build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
print(end_points)
saver = tf.train.Saver()

with tf.Session() as sess:
	#恢复权重
	saver.restore(sess, './vgg_16.ckpt')#权重保存为.ckpt则需要加上后缀
	
	# Get input and output tensors
        # 需要特别注意,get_tensor_by_name后面传入的参数,如果没有重复,需要在后面加上“:0”
        # sess.graph等价于tf.get_default_graph()
	input = sess.graph.get_tensor_by_name('inputs:0')
	output = sess.graph.get_tensor_by_name('vgg_16/fc8/squeezed:0')
	
	# Run forward pass to calculate pred
        #使用不同的数据运行相同的网络,只需将新数据通过feed_dict传递到网络即可。
	pred = sess.run(output, feed_dict={input:res_image})
	#得到使用vgg网络对输入图片的分类结果
	print(np.argmax(pred, 1))

3.扩展原始模型

import tensorflow as tf
import vgg
import numpy as np
import cv2

image = cv2.imread('./cat.18.jpg')
print(image.shape)
res = cv2.resize(image, (224, 224))
res_image = np.expand_dims(res, 0)
print(res_image.shape, type(res_image))

# build graph
graph = tf.Graph
inputs = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
net, end_points = vgg.vgg_16(inputs, num_classes=1000)
print(end_points)
saver = tf.train.Saver()

with tf.Session() as sess:
	# 恢复权重
	saver.restore(sess, './vgg_16.ckpt')  # 权重保存为.ckpt则需要加上后缀
	
	# 明确的网络的输入输出,通过get_tensor_by_name()获取变量
	input = sess.graph.get_tensor_by_name('inputs:0')
	output = sess.graph.get_tensor_by_name('vgg_16/fc8/squeezed:0')
	
	# add more operations to the graph
	# 这里只是简单示例,也可以加上新的网络层。
	pred = tf.argmax(output, 1)
	
	# 使用不同的数据运行扩展后的网络(这里扩展后的网络不涉及变量,可以直接使用扩展后的网络进行测试)
	pred = sess.run(pred, feed_dict={input: res_image})
	print(pred)

注意:

扩展的网络结构是可以在Session外定义的,但是需要在原计算图搭建好之后,因为这涉及到Graph中变量的名称。

4.微调

变量ensorflow as tf
import vgg
import numpy as np
import cv2
from skimage import io
import os

# -----------------------------------------准备数据--------------------------------------
#这里以单张图片作为示例,简单说明原理
image = cv2.imread('./cat.18.jpg')
print(image.shape)
res_image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_CUBIC)#vgg_16有全连接层,需要固定输入尺寸
print(res_image.shape)
res_image = np.expand_dims(res_image, axis=0)#网络输入为四维[batch_size, height, width, channels]
print(res_image.shape)
labels = [[1,0]]#标签

# -----------------------------------------恢复图------------------------------------------
#恢复图的方式有很多,这里采用手动构造一个跟保存权重时一样的graph
graph = tf.get_default_graph()

input = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')
y_ = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='labels')

# net=[batch, 2]其中2表示二分类,注意官网给出的vgg_16最终的输出没有经过softmax层
net, end_points = vgg.vgg_16(input, num_classes=2)  # 保存的权重模型针对的num_classes=1000,这里改为num_classes=2,因此最后一层需要重新训练
print(net, end_points)  # net是网络的输出;end_points是所有变量的集合

#add more operations to the graph
y = tf.nn.softmax(net)  # 输出0-1之间的概率
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
output_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='vgg_16/fc8')  # 注意这里的scope是定义graph时 name_scope的名字,不要加:0
print(output_vars)

# loss只作用在var_list列表中的变量,也就是说只训练var_list中的变量,其余变量保持不变。若不指定var_list,则默认重新训练所有变量
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy,var_list=output_vars)

# ----------------------------------------恢复权重------------------------------------------
var = tf.global_variables()  # 获取所有变量
print(var)
# var_to_restore = [val for val in var if 'conv1' in val.name or 'conv2' in val.name]#保留变量中含有conv1、conv2的变量
var_to_restore = [val for val in var if 'fc8' not in val.name]  # 保留变量名中不含有fc8的变量
print(var_to_restore)

saver = tf.train.Saver(var_to_restore)  # 恢复var_to_restore列表中的变量(最后一层变量fc8不恢复)

with tf.Session() as sess:
	# restore恢复变量值也是变量初始化的一种方式,对于没有restore的变量需要单独初始化
	# 注意如果使用全局初始化,则应在全局初始化后再调用saver.restore()。相当于先通过全局初始化赋值,再通过restore重新赋值。
	saver.restore(sess, './vgg_16.ckpt')  # 权重保存为.ckpt则需要加上后缀

	var_to_init = [val for val in var if 'fc8' in val.name]  # 保留变量名中含有fc8的变量

	# tf.variable_initializers(tf.global_variables())等价于tf.global_variables_initializer()
	sess.run(tf.variables_initializer(var_to_init))  # 没有restore的变量需要单独初始化
	# sess.run(tf.global_variables_initializer())

	# 用w1,w8测试权重恢复成功没有.正确的情况应该是:w1的值不变,w8的值随机
	w1 = sess.graph.get_tensor_by_name('vgg_16/conv1/conv1_1/weights:0')
	print(sess.run(w1, feed_dict={input: res_image}))

	w8 = sess.graph.get_tensor_by_name('vgg_16/fc8/weights:0')
	print('w8', sess.run(w8, feed_dict={input: res_image}))
	
	sess.run(train_op, feed_dict={input:res_image, y_:labels})

注意:

对未参与加载的变量也需要进行初始化。可以先对全部的变量进行初始化,之后再继续模型加载。注意二者顺序不可互换,因为模型加载相当于是参数初始化的过程,如果先加载模型之后又进行全部变量初始化,则会将已加载的模型参数覆盖成新的在当前graph中定义的初始化值。

因为将Placeholder和Operation都视为特殊的变量,只不过在ckpt中不保存他们的值,而constant又是作为Variable的一部分,其值必然会被保存。所以实际上ckpt文件会保存每一种数据类型,只不过有的并没有对数据值进行保存。所以通过自己重新搭建graph再由ckpt对参数值进行恢复是可行的。任意一种类型的数据都可以被恢复,只不过有的数据值(placeholder对应的数据值)需要通过feed_dict将数据喂入。

 

五、补充

1 .pb格式的文件

上面提到对于恢复的模型可以直接用来进行测试。对于不再需要改动的模型,我们可以将其保存为.pb格式的文件。

为什么要生成pb文件呢?简单来说就是直接通过tf.saver保存的模型文件其参数和图是分开的。这种形式方便对程序进行微小的改动。但是对于训练好,以后不再需要改动的模型这种形式就不是很必要了。

pb文件就是将变量的值固定下来,直接“烧”到图里面。这个时候只需用户提供一个输入,我们就可以通过模型得到一个输出给用户。pb文件一方面可提供给用户做离线的预测;另一方面,对于线上的模型,一般是通过C++或者C语言编写的程序进行调用。所以模型最终都是写成pb格式的文件。

2 .npy格式的文件

tensorflow保存的模型文件只能在tensorflow框架下使用,不利于将模型权重导入到其他框架使用,同时保存的模型文件无法直接查看。因此经常会考虑转换为.npy格式。.npy文件里的权重值是以数组的形式保存着的,方便查看。

 

总结

1.

Tensorflow中,关键的是Tensor和Opeation。是在计算图中的定义,他们才是本质。这里蕴涵着OOP的思想,对象本身才是关键,而指向其的标签,终究只是标签而已。

例如

input = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs')

关键的是tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3], name='inputs’)这个Tensor对象本身,而不是指向他的表情input。

2.

.meta文件保存的是graph的结构

.ckpt文件保存的是变量以及变量的值。值得注意的是,这里强调的是变量,tf.Variable(),所以这里不得不提到tensorflow中的几种Tensor类型:Variable、constant、placeholder。其中,Variable是正经的变量,他以及他的值会保存在.ckpt文件中毫无疑问;而constant,是常量,但是会发现很多时候tensorflow中的Variable初始化时都是使用constant,所以有理由认为Variable中的数据值都是以constant的形式被保存的,所以constant也会被保存在.ckpt文件中;placeholder,ckpt将他视为特殊的Variable,之所以特殊,是因为在.ckpt中会将其保存,但是由于他是一个占位符,所以他在原模型Graph中喂入的数据值并不会被保存,所以他就特殊在只保存了他这个变量,但是他的值并没有得到保存,所以在模型恢复以后,还是一样地需要使用feed_dict对其喂入数据,以此使得整个Graph运行起来。

还有一个就是Tensorflow中的operation,其依附于相关联的Tensor,所以也会被保存至.ckpt文件中,而其值是否被保存,就取决于operation所关联的Tensor是Variable、constant还是placeholder。

同时,对于要获取恢复后的某一Tensor,无论是Variable、constant、placeholder还是operation依附的Tensor,都可以使用graph.get_tensor_by_name来获取它的引用。

3.

问题:

关于模型恢复时,采用手动搭建网络,那么新搭建的graph中的变量Tensor名称是否需要和已训练好的模型中相应的Tensor的名称一致?

回答:

未知,应该是要的。只不过如果是手动搭建,名称会自动分配(有自己的一套命名规则),这可能也是要求手动搭建的graph要和已训练模型训练时的graph相一致的原因吧,只有一模一样的graph在自动命名时相应变量的名称才会一致。所以要搭建新网络结构,必须在手动搭建完原graph之后。

4.

强调Fine-Tune时只进行部分参数的更新的方法:

# loss只作用在var_list列表中的变量,也就是说只训练var_list中的变量,其余变量保持不变。若不指定var_list,则默认重新训练所有变量
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy,var_list=output_vars)

5.

总结一下从ckpt文件中恢复模型的Fine-tune过程:

----------恢复图结构-------------

->恢复原本的模型

->修改模型,增加新的网络结构,对新变量使用scope变量管理机制,便于变量的管理。注意设置新的变量的参数Trainable=True

->获取需要进行Fine-Tune的变量,组成列表。通过使用tf.get_collections()来进行变量的获取。Keys和scope为所需传递的参数,通常Keys=tf.GraphsKeys.TRAINABLE_VARIABLES,表示所有可训练的变量。因为只有变量是可训练的,才有Fine-Tune的意义,或者说才可以进行Fine-Tune。而scope是对应的名称管理时说使用的名称。

题外话:获取任意变量的方式还有graph.get_tensor_by_name,但是所需要的参数是变量的完整名称,这点就相对困难。所以如果是要获取单个变量,则可以使用get_tensor_by_name,如果是要获取大量变量,则使用tf.get_collections更为方便。

-----------恢复值--------------

->将需要使用模型中存储变量值的变量根据模型存储的变量值进行恢复,通过指定变量来恢复模型来实现。模型变量值的恢复本质就是变量的初始化。再对不需要使用模型中存储变量值的变量、没有在模型中存储变量值的变量进行初始化。总之,就是要完成所有变量的初始化,无论是使用模型恢复的方式还是直接初始化的方式。

 

=============

若要进行fine-tune,则首先需要构建出要fine-tune的模型以及fine-tune的流程,即对某些变量进行的训练的过程。

(1)而对于fine-tune模型结构的构建,可以是从已保存的模型中,选择.meta文件,使用tf.train.import_meta_graph函数,直接恢复保存好的已训练的模型结构,也可以是自己手动根据保存的已训练好的模型进行模型结构的搭建。无论是使用哪种方式,都可以在Session外先将整个模型结构以及新添加的网络结构搭建好,之后再进入Session中run整个流程图,得到某个节点的Tensor。

对于第一种方式,Default Graph的获取可以在Session外进行。所以可以在Session外使用graph.get_tensor_by_name获取某个特定的节点,从而可以继续从某个节点延伸出去搭建新的网络结构。

在调用

saver = tf.train.import_meta_graph("Model/model_save.ckpt.meta")

时,就将该模型中的结构导入了当前default graph中了。所以之后的graph结构获取,可以通过tf.get_default_graph来实现。

# -*- coding:utf-8 -*-
import tensorflow as tf

saver = tf.train.import_meta_graph("Model/model_save.ckpt.meta")

# 获取当前默认图
# 如果只有一个会话时,会话中的图也就是默认的图
# 默认图的获取无需在Session中进行
graph = tf.get_default_graph()

# operations的名称可以使用graph.get_operations函数获取。
# ops1 = graph.get_operations()
# for i in ops1:
#     print i.name
conv5_1_add = graph.get_tensor_by_name('conv5_1/add:0')

# 随意添加的网络的新分支
c = conv5_1_add + tf.Variable([1.])

with tf.Session() as sess:
    pass

而对于第二种方式,则直接在手动的搭建的原模型的网络后加入新的网络结构即可。同样可以在Session外完成。

 

(2)模型搭建完成,接下来就是对模型变量数据的恢复。如果是使用第一种直接从.meta文件中恢复模型结构的方式,且同时需要对特定的模型变量进行数据的恢复(即变量初始化的一种方式),就需要在restore的时候,额外使用一个Saver类实例。以下是saver.restore()的官方注释:

""" Restores previously saved variables.

    This method runs the ops added by the constructor for restoring variables.
    It requires a session in which the graph was launched.  The variables to
    restore do not have to have been initialized, as restoring is itself a way
    to initialize variables.
    
    The `save_path` argument is typically a value previously returned from a
    `save()` call, or a call to `latest_checkpoint()`."""

可知,需要传入一个graph所在的Session。这个graph中包含了所需要恢复的变量所在的结构,所以额外使用的Saver类实例就是为了将graph中的变量数据恢复,即对这些变量进行初始化。一个Saver类实例在用于变量恢复时,创建时传入的参数需要是指定要恢复的变量,这时与Session和Graph都没有联系,也就是说,在用于恢复变量数据的Saver类实例的创建时,是没有指定针对哪个Graph和哪个Session的。同时在使用train.import_meta_graph后就会将模型设为default graph。之后再使用额外的Saver类实例,同时在声明时指定要恢复的变量,之后再在Session中将sess和变量数据所在的文件所在路径传递给这个类实例的restore方法。可以实现模型变量数据的恢复。

# -*- coding:utf-8 -*-
import tensorflow as tf

# saver1 for file .meta
saver1 = tf.train.import_meta_graph("Model/model_save.ckpt.meta")
global_variables = [v for v in tf.global_variables()]

graph = tf.get_default_graph()

# saver2 for specific Variables restore
saver2 = tf.train.Saver(global_variables)

#ops1 = graph.get_operations()
#for i in ops1:
#    print i.name
conv5_1_add = graph.get_tensor_by_name('conv5_1/add:0')
c = conv5_1_add + tf.Variable([1.])

with tf.Session() as sess:

    saver2.restore(sess, "Model/model_save.ckpt")
    ...

同时,还需要对未从保存的模型中恢复的变量的数据进行恢复,更准确地说是初始化。对于这些并非从模型中恢复的变量的初始化,就使用tf.variables_initializer(var_to_init)实现即可,其中var_to_init为需要初始化的变量列表。

var_to_init = [val for val in var if 'fc8' in tf.global_variables().name]  
# 保留变量名中含有fc8的变量

# tf.variable_initializers(tf.global_variables())
# 等价于tf.global_variables_initializer()
sess.run(tf.variables_initializer(var_to_init))  # 没有restore的变量需要单独初始化
# sess.run(tf.global_variables_initializer())

这里还有官方教程:

https://blog.csdn.net/k87974/article/details/80753094

 

 

 

 

  • 6
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值