一个快速完整的教程来保存和恢复Tensorflow
http://cv-tricks.com/tensorflow-tutorial/-saverestore-tensorflow-models-quick-complete-tutorial/
在这个Tensorflow教程中,我将解释:
- Tensorflow模型是怎样的?
- 如何保存Tensorflow模型?
- 如何恢复预测/传输学习的Tensorflow模型?
- 如何使用导入的预训练模型进行微调和修改
本教程假定您对训练神经网络有一些想法。否则,请按照本教程进行操作并返回此处。
1.什么是Tensorflow模型?
在训练完神经网络之后,您需要将其保存以供将来使用并部署到生产环境。那么,什么是Tensorflow模型?Tensorflow模型主要包含我们已经培训的网络设计或者图形和网络参数的值。因此,Tensorflow模型有两个主要文件:
a)元图:
这是一个保存完整Tensorflow图的协议缓冲区; 即所有变量,操作,集合等。该文件具有.meta扩展名。
b)检查点文件:
这是一个二进制文件,其中包含权重,偏差,梯度和所有其他变量的所有值。这个文件有一个扩展名。CKPT。但是,Tensorflow已经从版本0.11改变了这一点。现在,而不是单个.ckpt文件,我们有两个文件:
1
2
3
|
mymodel
.
data
-
00000
-
of
-
00001
mymodel
.
index
|
.data文件是包含我们的训练变量的文件,我们将继续。
除此之外,Tensorflow还有一个名为checkpoint的文件,它只保存最新检查点文件的记录。
因此,总而言之,版本大于0.10的Tensorflow模型如下所示:
而0.11之前的Tensorflow模型仅包含三个文件:
1
2
3
4
|
inception_v1
.
meta
inception_v1
.
ckpt
checkpoint
|
现在我们知道了Tensorflow模型的外观,我们来学习如何保存模型。
2.保存Tensorflow模型:
假设您正在训练用于图像分类的卷积神经网络。作为一种标准做法,您需要关注损失和准确性数字。一旦您看到网络已经融合,您可以手动停止训练,或者您将运行固定数量的时期训练。培训完成后,我们希望将所有变量和网络图保存到一个文件以供将来使用。因此,在Tensorflow中,您想要保存要为其创建tf.train.Saver()类实例的所有参数的图形和值。
saver = tf.train.Saver()
请记住,Tensorflow变量只在会话中存在。因此,您必须通过调用刚创建的保存程序对象上的save方法将模型保存在会话中。
1
2
|
saver
.
save
(
sess
,
'my-test-model'
)
|
这里,sess是会话对象,而'my-test-model'是你想要给你的模型的名字。我们来看一个完整的例子:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
import
tensorflow
as
tf
w1
=
tf
.
Variable
(
tf
.
random_normal
(
shape
=
[
2
]
)
,
name
=
'w1'
)
w2
=
tf
.
Variable
(
tf
.
random_normal
(
shape
=
[
5
]
)
,
name
=
'w2'
)
saver
=
tf
.
train
.
Saver
(
)
sess
=
tf
.
Session
(
)
sess
.
run
(
tf
.
global_variables_initializer
(
)
)
saver
.
save
(
sess
,
'my_test_model'
)
# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint
|
如果我们在1000次迭代后保存模型,我们将通过传递步数来调用save:
saver.save(sess, 'my_test_model',global_step=1000)
这只会将'-1000'附加到型号名称上,并创建以下文件:
1
2
3
4
5
|
my_test_model
-
1000.index
my_test_model
-
1000.meta
my_test_model
-
1000.data
-
00000
-
of
-
00001
checkpoint
|
比方说,在训练时,我们在每1000次迭代后保存模型,所以.meta文件是第一次创建(在第1000次迭代中),我们不需要每次都重新创建.meta文件(所以,将.meta文件保存为2000,3000 ..或任何其他迭代)。我们只保存模型以进一步迭代,因为图形不会改变。因此,当我们不想编写元图时,我们使用这个:
1
2
|
saver
.
save
(
sess
,
'my-model'
,
global_step
=
step
,
write_meta_graph
=
False
)
|
如果您只想保留4个最新型号,并且想要在训练期间每2小时保存一个型号,则可以使用max_to_keep和keep_checkpoint_every_n_hours。
1
2
3
|
#saves a model every 2 hours and maximum 4 latest models are saved.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
4
,
keep_checkpoint_every_n_hours
=
2
)
|
请注意,如果我们没有在tf.train.Saver()中指定任何内容,它会保存所有变量。如果我们不想保存所有的变量而只保存其中的一部分,会怎样呢?我们可以指定我们想要保存的变量/集合。在创建tf.train.Saver实例时,我们将它传递给我们想要保存的变量的列表或字典。我们来看一个例子:
1
2
3
4
5
6
7
8
|
import
tensorflow
as
tf
w1
=
tf
.
Variable
(
tf
.
random_normal
(
shape
=
[
2
]
)
,
name
=
'w1'
)
w2
=
tf
.
Variable
(
tf
.
random_normal
(
shape
=
[
5
]
)
,
name
=
'w2'
)
saver
=
tf
.
train
.
Saver
(
[
w1
,
w2
]
)
sess
=
tf
.
Session
(
)
sess
.
run
(
tf
.
global_variables_initializer
(
)
)
saver
.
save
(
sess
,
'my_test_model'
,
global_step
=
1000
)
|
这可用于在需要时保存Tensorflow图的特定部分。
3.导入预先训练的模型:
如果你想使用别人的预先训练好的模型进行微调,你需要做两件事情:
a)创建网络:
您可以通过编写Python代码来创建网络,以手动创建每个图层作为原始模型。然而,如果你仔细想想,我们已经将网络保存在.meta文件中,我们可以使用tf.train.import()函数来重新创建网络,如下所示:saver = tf.train.import_meta_graph('my_test_model-1000.meta')
请记住,import_meta_graph会将.meta文件中定义的网络附加到当前图形中。因此,这将为您创建图形/网络,但我们仍然需要加载我们在此图上训练过的参数的值。
b)加载参数:
我们可以通过调用该保存程序中的恢复来恢复网络的参数,该程序是tf.train.Saver()类的一个实例。
1
2
3
4
|
with
tf
.
Session
(
)
as
sess
:
new_saver
=
tf
.
train
.
import_meta_graph
(
'my_test_model-1000.meta'
)
new_saver
.
restore
(
sess
,
tf
.
train
.
latest_checkpoint
(
'./'
)
)
|
在此之后,像w1和w2这样的张量值已经恢复并可以被访问:
1
2
3
4
5
6
|
with
tf
.
Session
(
)
as
sess
:
saver
=
tf
.
train
.
import_meta_graph
(
'my-model-1000.meta'
)
saver
.
restore
(
sess
,
tf
.
train
.
latest_checkpoint
(
'./'
)
)
print
(
sess
.
run
(
'w1:0'
)
)
##Model has been restored. Above statement will print the saved value of w1.
|
所以,现在您已经了解了Tensorflow模型的保存和导入工作原理。在下一节中,我已经描述了上述的实际用法来加载任何预先训练好的模型。
4.使用恢复的模型
既然您已经了解了如何保存和恢复Tensorflow模型,那么让我们开发一个实用指南,以恢复任何预先训练好的模型,并将其用于预测,微调或进一步培训。无论何时使用Tensorflow,您都可以定义一个图表,其中包含示例(训练数据)和一些超参数,例如学习速率,全局步长等。使用占位符提供所有训练数据和超参数是一种标准做法。我们使用占位符构建一个小型网络并保存它。请注意,保存网络时,不会保存占位符的值。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
import
tensorflow
as
tf
#Prepare to feed input, i.e. feed_dict and placeholders
w1
=
tf
.
placeholder
(
"float"
,
name
=
"w1"
)
w2
=
tf
.
placeholder
(
"float"
,
name
=
"w2"
)
b1
=
tf
.
Variable
(
2.0
,
name
=
"bias"
)
feed_dict
=
{
w1
:
4
,
w2
:
8
}
#Define a test operation that we will restore
w3
=
tf
.
add
(
w1
,
w2
)
w4
=
tf
.
multiply
(
w3
,
b1
,
name
=
"op_to_restore"
)
sess
=
tf
.
Session
(
)
sess
.
run
(
tf
.
global_variables_initializer
(
)
)
#Create a saver object which will save all the variables
saver
=
tf
.
train
.
Saver
(
)
#Run the operation by feeding input
print
sess
.
run
(
w4
,
feed_dict
)
#Prints 24 which is sum of (w1+w2)*b1
#Now, save the graph
saver
.
save
(
sess
,
'my_test_model'
,
global_step
=
1000
)
|
现在,当我们想要恢复它时,我们不仅需要恢复图形和权重,还要准备一个新的feed_dict,将新的训练数据馈送到网络。我们可以通过graph.get_tensor_by_name()方法获得对这些保存的操作和占位符变量的引用。
1
2
3
4
5
6
|
#How to access saved variable/Tensor/placeholders
w1
=
graph
.
get_tensor_by_name
(
"w1:0"
)
## How to access saved operation
op_to_restore
=
graph
.
get_tensor_by_name
(
"op_to_restore:0"
)
|
如果我们只想用不同的数据运行同一个网络,只需将新数据通过feed_dict传递给网络即可。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
import
tensorflow
as
tf
sess
=
tf
.
Session
(
)
#First let's load meta graph and restore weights
saver
=
tf
.
train
.
import_meta_graph
(
'my_test_model-1000.meta'
)
saver
.
restore
(
sess
,
tf
.
train
.
latest_checkpoint
(
'./'
)
)
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph
=
tf
.
get_default_graph
(
)
w1
=
graph
.
get_tensor_by_name
(
"w1:0"
)
w2
=
graph
.
get_tensor_by_name
(
"w2:0"
)
feed_dict
=
{
w1
:
13.0
,
w2
:
17.0
}
#Now, access the op that you want to run.
op_to_restore
=
graph
.
get_tensor_by_name
(
"op_to_restore:0"
)
print
sess
.
run
(
op_to_restore
,
feed_dict
)
#This will print 60 which is calculated
#using new values of w1 and w2 and saved value of b1.
|
如果您想通过添加更多图层并添加更多图层来为图表添加更多操作。当然你也可以这样做。看这里:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
import
tensorflow
as
tf
sess
=
tf
.
Session
(
)
#First let's load meta graph and restore weights
saver
=
tf
.
train
.
import_meta_graph
(
'my_test_model-1000.meta'
)
saver
.
restore
(
sess
,
tf
.
train
.
latest_checkpoint
(
'./'
)
)
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph
=
tf
.
get_default_graph
(
)
w1
=
graph
.
get_tensor_by_name
(
"w1:0"
)
w2
=
graph
.
get_tensor_by_name
(
"w2:0"
)
feed_dict
=
{
w1
:
13.0
,
w2
:
17.0
}
#Now, access the op that you want to run.
op_to_restore
=
graph
.
get_tensor_by_name
(
"op_to_restore:0"
)
#Add more to the current graph
add_on_op
=
tf
.
multiply
(
op_to_restore
,
2
)
print
sess
.
run
(
add_on_op
,
feed_dict
)
#This will print 120.
|
但是,您是否可以恢复部分旧图形和插件以进行微调?当然,您可以通过graph.get_tensor_by_name()方法访问相应的操作,并在其上创建图形。这是一个真实世界的例子。在这里,我们使用元图加载一个vgg预训练网络,并将最后一层中的输出数量更改为2,以便用新数据进行微调。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
.
.
.
.
.
.
.
.
.
.
.
.
saver
=
tf
.
train
.
import_meta_graph
(
'vgg.meta'
)
# Access the graph
graph
=
tf
.
get_default_graph
(
)
## Prepare the feed_dict for feeding data for fine-tuning
#Access the appropriate output for fine-tuning
fc7
=
graph
.
get_tensor_by_name
(
'fc7:0'
)
#use this if you only want to change gradients of the last layer
fc7
=
tf
.
stop_gradient
(
fc7
)
# It's an identity function
fc7_shape
=
fc7
.
get_shape
(
)
.
as_list
(
)
new_outputs
=
2
weights
=
tf
.
Variable
(
tf
.
truncated_normal
(
[
fc7_shape
[
3
]
,
num_outputs
]
,
stddev
=
0.05
)
)
biases
=
tf
.
Variable
(
tf
.
constant
(
0.05
,
shape
=
[
num_outputs
]
)
)
output
=
tf
.
matmul
(
fc7
,
weights
)
+
biases
pred
=
tf
.
nn
.
softmax
(
output
)
# Now, you run this with fine-tuning data in sess.run()
|