tensorflow存储、恢复神经网络结构和变量

tensorflow提供了变量保存接口,方便储存训练好的网络参数,以便进行预测和继续训练。

save核心代码只有两句话,简化DEMO如下(代码中的global_step后边会解释):

import tensorflow as tf

W1 = tf.Variable([[1,2,3],[4,5,6]])#, name = 'variable1'
W2 = tf.Variable([[11,22,33],[44,55,66]])#, name = 'variable2'
print(W2)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(W1))
print(sess.run(W2))
#save
saver = tf.train.Saver()
saver.save(sess,'./my_model_saved',global_step = 0)

注意这里不能加name,会报错。

下边是restore的代码,注意,如果想要在同一个文件(jupyter日常)同时写这两段代码,一定要在中间加上重置默认graph,不然存在干扰,你以为你restore了变量,其实可能是前边定义的。

tf.reset_default_graph()

要点,变量名要一致,同样都是W1和W2,换名字不行。

注意看save和restore的文件路径是不一样的,这是因为我加了global_step,global_step在实际操作中是肯定会用到的,这个后缀方便你知道训练到了多少步,哪个文件进度更新,也可以随意选择想要加载的训练结果。不过本例并没有实际运行训练过程和更新global_step,实际操作是要变量自增的。(想要简化代码,把global_step参数去掉,并把restore参数中文件名的后缀去掉就行了。)


#reset graph and session,no need to delete code above.
tf.reset_default_graph()

W1 = tf.Variable([[1223,2,3],[433,5123,6]])
W2 = tf.Variable([[0,0,0],[0,0,0]])
print(W2)
sess = tf.Session()#after reset_default_graph(),a new session is necessary

#load
saver = tf.train.Saver()#after reset_default_graph(),a new saver is necessary
#saver.restore(sess,'./my_model_saved')#ValueError1--not a valid checkpoint

saver.restore(sess,'./my_model_saved-0')#ValueError1--solution1
#saver.restore(sess, './my_model_saved',global_step)#ValueError1--solution3--not support,no parameter named global_step

print(sess.run(W1))
print(sess.run(W2))

 运行结果:重新定义的W1和W2被restore成了前边存的值。

<tf.Variable 'Variable_1:0' shape=(2, 3) dtype=int32_ref>
[[1 2 3]
 [4 5 6]]
[[11 22 33]
 [44 55 66]]
<tf.Variable 'Variable_1:0' shape=(2, 3) dtype=int32_ref>
[[1 2 3]
 [4 5 6]]
[[11 22 33]
 [44 55 66]]

上边的代码的意思是你存了变量W1和W2,存储它们的值,然后在另一个网络中定义W1和W2,把前边存的值赋值给他们。这个功能和initializer是一样的,也就是说,你不需要再run一次initializer了

难道我一定要存所有变量且读取所有变量吗?也不是,其实Saver的参数列表有var_list的,可以选择性的存和选择性的读:

import tensorflow as tf
from tensorflow.python.tools import inspect_checkpoint as chkp

W1 = tf.Variable([[1,2,3],[4,5,6]])#, name = 'variable1'
W2 = tf.Variable([[11,22,33],[44,55,66]])#, name = 'variable2'
print(W2)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(W1))
print(sess.run(W2))
#save
saver = tf.train.Saver(var_list = {'variable1':W1, 'variable2':W2})
saver.save(sess,'./my_model_savedd')

################################################################################################
#reset graph and session,no need to delete code above.
tf.reset_default_graph()

W1 = tf.Variable([[1223,2,3],[433,5123,6]])
W2 = tf.Variable([[0,0,0],[0,0,0]])
sess = tf.Session()
#if all_tensors is True,print variable2 too
chkp.print_tensors_in_checkpoint_file('./my_model_savedd', tensor_name = 'variable1', all_tensors = False)#True
#load
saver = tf.train.Saver(var_list = {'variable1':W1})
Reverse = False
if Reverse == False:
    sess.run(tf.global_variables_initializer())#this is necessary
    saver.restore(sess,'./my_model_savedd')
else:
    saver.restore(sess, './my_model_savedd')
    sess.run(tf.global_variables_initializer())  # this is necessary

print('restored W1:\n', sess.run(W1))
print('initialized W2:\n', sess.run(W2))


<tf.Variable 'Variable_1:0' shape=(2, 3) dtype=int32_ref>
[[1 2 3]
 [4 5 6]]
[[11 22 33]
 [44 55 66]]
tensor_name:  variable1
[[1 2 3]
 [4 5 6]]
restored W1:
 [[1 2 3]
 [4 5 6]]
initialized W2:
 [[0 0 0]
 [0 0 0]]

可以看到运行结果,W1和存的一样,W2就是默认的000000,是选择性的存取。

需要注意的点:存的var_list要永远不小于取的var_list,一旦你使用var_list指定名称,对应的另一边,无论是存还是取,都需要用相同的名称,这是字典操作的基本要求。var_list罩不到的variable,需要用initializer去初始化。

重要:如果你想用initializer来初始化其他未restore变量,一定要注意代码执行顺序,先执行initializer,先执行initializer,先执行initializer!initializer和restore不会智能互相识别让步,是相互覆盖的关系,这就好比你运行a=1又运行a=2。可以把Reverse改成True来验证错误的结果。

chkp.print_tensors_in_checkpoint_file是打印变量的方法,打印出来更加直观

 

 

saver只是存了变量,并不存储网络结构,所以很麻烦,你得自己重新定义网络结构(这里是W1和W2),重新写一段代码。

那么有没有办法存储网络结构?方法肯定是有的——tf.train.write_graph()


代码如下:

from tensorflow.python.platform import gfile
import tensorflow as tf
import os

W1 = tf.Variable([[1,2,3],[4,5,6]], name = 'var1')
W2 = tf.Variable([[11,22,33],[44,55,66]], name = 'var2')
print(W1)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(W1))
print(sess.run(W2))

os.system('rm ./my_graph.pb')
tf.train.write_graph(sess.graph_def, './', 'my_graph.pb',False)

saver = tf.train.Saver(tf.global_variables())
os.system('rm ./my_model')
saver.save(sess,'./my_model',global_step = 0)

##############################################
tf.reset_default_graph()
###############################################
sess = tf.Session()#after reset_default_graph(),a new session is necessary

with gfile.FastGFile("./my_graph.pb", 'rb') as f:#"tmp/load/test.pb"
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
  sess.graph.as_default()
  tf.import_graph_def(graph_def,name='')

W1 = sess.graph.get_tensor_by_name("var1:0")#you have to get a graph firstly
W200 = sess.graph.get_tensor_by_name("var2:0")
YYY = sess.graph.get_tensor_by_name("var1:0")
print('before:', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, W1)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, W200)
print('after:', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

try:
  saver = tf.train.Saver(tf.global_variables())  # 'Saver' misnomer! Better: Persister!
except:
  pass
print("load data")

saver.restore(sess,'./my_model-0')

print(W1.eval(session=sess))
print(W200.eval(session=sess))

print(YYY.eval(session=sess))
print(W1)
print(YYY)



这里代码二合一,存取在同一个文件,同样不要忘了reset_default_graph()

操作分两部分:存取graph、存取variable。

注意看我在读取时用了变量W1和W200、YYY,不是W1和W2,也就是说Python变量名不受限了,赋值的变量数量也不受限。这就是存graph带来的一个好处,有graph就有了之前的网络结构,有了graph可以直接用name取Tensor来给Python变量赋值。

建好了graph,把变量加到collection,就可以restore了,YYY打印出来和W1一样,是同一个Tensor

<tf.Variable 'var1:0' shape=(2, 3) dtype=int32_ref>
[[1 2 3]
 [4 5 6]]
[[11 22 33]
 [44 55 66]]
rm: cannot remove './my_model': No such file or directory
before: []
after: [<tf.Tensor 'var1:0' shape=(2, 3) dtype=int32_ref>, <tf.Tensor 'var2:0' shape=(2, 3) dtype=int32_ref>]
load data
[[1 2 3]
 [4 5 6]]
[[11 22 33]
 [44 55 66]]
[[1 2 3]
 [4 5 6]]
Tensor("var1:0", shape=(2, 3), dtype=int32_ref)
Tensor("var1:0", shape=(2, 3), dtype=int32_ref)

 

关于global_step,不写个完整例子显得有点糊弄了,下面补充一个训练W拟合公式Y=2X并且随意选取节点进行预测的例子:

from tensorflow.python.platform import gfile
import tensorflow as tf
import os

Global_step = tf.Variable(0,name = 'global_step')

X = tf.constant(10.0)
Y = tf.constant(20.0)#Y=2*X
W = tf.Variable(1.1, name = 'weight')
Prediction = W*X
loss = (Y - Prediction)**2
train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss,global_step = Global_step)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())

for i in range(100):
  print('Prediction:',sess.run(Prediction))
  print('loss:',sess.run(loss))
  print('W:',sess.run(W))
  sess.run(train_op)
  print(sess.run(Global_step))
  os.system('rm ./tmp2/my_graph.pb')
  tf.train.write_graph(sess.graph_def, './tmp2', 'my_graph.pb',False)

  # os.system('rm ./tmp2/my_model')
  saver.save(sess,'./tmp2/my_model',global_step = Global_step)



##############################################
tf.reset_default_graph()
###############################################
sess = tf.Session()#after reset_default_graph(),a new session is necessary

with gfile.FastGFile("./tmp2/my_graph.pb", 'rb') as f:#"tmp/load/test.pb"
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
  sess.graph.as_default()
  tf.import_graph_def(graph_def,name='')

X = tf.placeholder(dtype = tf.float32)
W = sess.graph.get_tensor_by_name("weight:0")#you have to get a graph firstly
Prediction = X*W

print('before:', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, W)
print('after:', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

try:
  saver = tf.train.Saver(tf.global_variables())
except:
  pass
print("load data")

saver.restore(sess,'./tmp2/my_model-99')

print(W.eval(session=sess))
print(sess.run(Prediction,{X : 10.0}))
# print(sess.run(W2))


同样在一个文件完成存取,在运行训练操作时传入的global_step会自动被加1,就相当于一个自己传入的counter,我把每一步都用独立后缀存一个文件,然后任意指定读取,后缀-99是第99步的训练结果,这就是global_step的实际应用场景。

Prediction: 19.999998
loss: 3.637979e-12
W: 1.9999998
95
Prediction: 19.999998
loss: 3.637979e-12
W: 1.9999998
96
Prediction: 19.999998
loss: 3.637979e-12
W: 1.9999998
97
Prediction: 19.999998
loss: 3.637979e-12
W: 1.9999998
98
Prediction: 19.999998
loss: 3.637979e-12
W: 1.9999998
99
Prediction: 19.999998
loss: 3.637979e-12
W: 1.9999998
100
before: []
after: [<tf.Tensor 'weight:0' shape=() dtype=float32_ref>]
load data
1.9999998
19.999998

 

此外,还有一些高级方法。比如使用graph_util,graph_util.convert_variables_to_constants可以同时存储graph和variable到pb文件。

展示代码如下:

from tensorflow.python.platform import gfile
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
from tensorflow.python.tools import inspect_checkpoint as chkp

with tf.variable_scope('net'):
  Global_step = tf.Variable(0,name = 'global_step')

  X = tf.placeholder(dtype = tf.float32, name = 'input')
  Y = tf.constant(20.0)#Y=2*X
  W = tf.Variable(1.1, name = 'weight')
  print(W)
  Prediction = tf.multiply(W,X,name = 'Prediction')
  loss = (Y - Prediction)**2
train_op = tf.train.GradientDescentOptimizer(0.001).minimize(loss,global_step = Global_step)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.global_variables())

for i in range(100):
  # print('Prediction:',sess.run(Prediction, {X:10.0}))
  # print('loss:',sess.run(loss, {X:10.0}))
  # print('W:',sess.run(W))
  sess.run(train_op, {X:10.0})

output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        'net/Prediction'.split(","),
)

# Write GraphDef to file if output path has been given.

with gfile.GFile('my_final_graph.pb', "wb") as f:
  f.write(output_graph_def.SerializeToString())

##############################################
tf.reset_default_graph()
###############################################
sess = tf.Session()#after reset_default_graph(),a new session is necessary

with gfile.FastGFile("my_final_graph.pb", 'rb') as f:#"tmp/load/test.pb"
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
  sess.graph.as_default()
  tf.import_graph_def(graph_def,{})
# print('sess.graph:',sess.graph.get_operations())

Prediction = sess.graph.get_tensor_by_name("import/net/Prediction:0")#you have to get a graph firstly

# print('before:', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, Prediction)
# print('after:', tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))

print(sess.run(Prediction,{'import/net/input:0': 10.0}))
print(sess.run(Prediction,{'import/net/input:0': 15.0}))


代码说明:训练之后存graph,手动指定存储Prediction(但是整个网络结构和变量都是有的),然后将Prediction取出来,直接传入'import/net/input:0'就可以进行预测了。

关于'import/'前缀,存的时候确实没有'import/',但是取的时候必须要有,也许怕你导入之后和原生的重名冲突,所以自动加上了。但是!'import/'不是什么万能排他标识,没有任何机制保护,只是普通scope名称!需要注意自己定义的scope名称不要乱用'import/'前缀,否则scope还是乱的,会发生重定义,新定义的变量覆盖读文件的变量,然后告诉你用的是未定义的新变量。

比如下边这种用法,一般不会故意干这种事,只是提一下存在这种可能,暂时不探讨细节。

# with tf.variable_scope('import/net'):
#   X = tf.placeholder(dtype = tf.float32, name = 'input')
#   Y = tf.constant(20.0)#Y=2*X
#   W = tf.Variable(1.1, name = 'weight')
#   print(W)
#   Prediction = tf.multiply(W,X,name = 'Prediction')

 

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值