Tensorflow-变量保存与导入

参考:(原文参考链接在转载时丢失)


1、基本用法

#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np


logdir='./output/'

# Two variables created inside the 'conv' variable scope; their checkpoint
# names will be 'conv/w' and 'conv/b'.
with tf.variable_scope('conv'):
    w=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer)
    b=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer)
sess=tf.InteractiveSession()
saver=tf.train.Saver([w]) # with no argument Saver handles all variables; here only w is saved
tf.global_variables_initializer().run() # initialize all variables

# Check whether a checkpoint file was saved previously; if so, restore w from it
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
# BUG FIX: the original only built the init op (`tf.variables_initializer([b])`)
# and never ran it, so b was never re-initialized here; .run() executes the op.
tf.variables_initializer([b]).run() # (re-)initialize variable b

saver.save(sess,logdir+'model.ckpt')
print('w',w.eval())
print('-----------')
print('b',b.eval())
sess.close()

2、变量从ckpt中提取,没有的需初始化

  • 第一次运行
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir='./output/'

# First run: only the 'conv' variables (w1, b1) exist in the graph.
with tf.variable_scope('conv'):
    w1=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer)
    b1=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer)
# with tf.variable_scope('conv2'):
#     w2=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer)
#     b2=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer)

sess=tf.InteractiveSession()
tf.global_variables_initializer().run() # initialize all variables

# Check whether a checkpoint file was saved previously
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        saver = tf.train.Saver()  # no argument: restore every graph variable (here just w1, b1)
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver=None
    except tf.errors.NotFoundError:
        # BUG FIX: was a bare `except:`, which swallows every exception
        # (including KeyboardInterrupt/SystemExit). Saver.restore raises
        # NotFoundError when the checkpoint lacks some graph variables,
        # so catch exactly that and fall back to a partial restore.
        saver = tf.train.Saver([w1,b1])  # restore only w1 and b1
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
# tf.variables_initializer([b1]) # initialize variable b1

saver=tf.train.Saver() # no argument: save all variables (here w1 and b1)
saver.save(sess,logdir+'model.ckpt')
print('w',w1.eval())
print('-----------')
print('b',b1.eval())
print('-----------')
# print('w',w2.eval())
# print('-----------')
# print('b',b2.eval())

sess.close()

(此处原有运行结果截图,图片缺失)

  • 第二次运行
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir='./output/'

# Second run: the graph now also has 'conv2' variables (w2, b2), but the
# checkpoint from the first run only contains w1 and b1.
with tf.variable_scope('conv'):
    w1=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer)
    b1=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    w2=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer)
    b2=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer)

sess=tf.InteractiveSession()

tf.global_variables_initializer().run() # initialize all variables

# Check whether a checkpoint file was saved previously
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        saver = tf.train.Saver()  # no argument: try to restore every graph variable
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
    except tf.errors.NotFoundError:
        # BUG FIX: was a bare `except:`. Saver.restore raises NotFoundError
        # here because the checkpoint has no w2/b2; fall back to restoring
        # only the variables the checkpoint actually contains.
        saver = tf.train.Saver([w1, b1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver=None
# tf.variables_initializer([b1]) # initialize variable b1

saver=tf.train.Saver() # no argument: save all variables, now including w2 and b2
saver.save(sess,logdir+'model.ckpt')
print('w',w1.eval())
print('-----------')
print('b',b1.eval())
print('-----------')
print('w',w2.eval())
print('-----------')
print('b',b2.eval())

sess.close()

(此处原有运行结果截图,图片缺失)

  • 第三次运行
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir='./output/'

# Third run: the checkpoint from the previous run already contains
# w1, b1, w2 and b2; this run deliberately restores only w1 and b1.
with tf.variable_scope('conv'):
    w1=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer)
    b1=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    w2=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer)
    b2=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer)

sess=tf.InteractiveSession()

tf.global_variables_initializer().run() # initialize all variables

# Check whether a checkpoint file was saved previously
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    # try:
    #     saver = tf.train.Saver()  # no argument: restore all variables
    #     saver.restore(sess, ckpt.model_checkpoint_path)
    #     saver = None
    # except:
    saver = tf.train.Saver([w1, b1])  # checkpoint holds w1,b1,w2,b2; restore only w1 and b1
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver=None
# tf.variables_initializer([w2,b2]) # initialize variables w2, b2

saver=tf.train.Saver() # no argument: save all variables, including w2 and b2
saver.save(sess,logdir+'model.ckpt')
print('w',w1.eval())
print('-----------')
print('b',b1.eval())
print('-----------')
print('w',w2.eval())
print('-----------')
print('b',b2.eval())

sess.close()

(此处原有运行结果截图,图片缺失)

3、总结

  • 保存变量
tf.global_variables_initializer().run() # initialize all variables
saver=tf.train.Saver() # no argument: save all variables by default
saver=tf.train.Saver([w,b]) # save only the listed variables
saver.save(sess,logdir+'model.ckpt')
  • 导入变量
tf.global_variables_initializer().run() # initialize all variables

# Check whether a checkpoint file was saved previously
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        saver = tf.train.Saver()  # no argument: restore all variables
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
    except tf.errors.NotFoundError:
        # BUG FIX: was a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit. Saver.restore raises NotFoundError
        # when the checkpoint lacks some graph variables.
        saver = tf.train.Saver([w1, b1])  # restore a subset of variables
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver=None

如果保存的变量有w1,b1,w2,b2,但只导入w1,b1,并对w2,b2重新初始化、训练等,则使用

saver = tf.train.Saver([w1, b1])

如果保存的变量中只有w1,b1,现在新增变量w2,b2 则只能导入w1,b1

saver = tf.train.Saver([w1, b1])

通过这种方法可以实现,只导入模型的前n-1层参数,而对第n层参数重新初始化训练,这样就能很好的实现迁移学习

4、补充 结合variable_scope

#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir='./output/'

with tf.variable_scope('conv'):
    w=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer)
    b=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer)
# with tf.variable_scope('conv2'):
#     w=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer)
#     b=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer)

sess=tf.InteractiveSession()

tf.global_variables_initializer().run() # initialize all variables

# Check whether a checkpoint file was saved previously
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    # BUG FIX: the original built the Saver from
    # tf.variable_op_scope(w, name_or_scope='conv/w:0').args[0] — a fragile
    # hack on a deprecated API that merely returns its first argument.
    # The supported way to look variables up by name is to re-enter their
    # scope with reuse=True and call tf.get_variable().
    with tf.variable_scope('conv', reuse=True):
        saver = tf.train.Saver([tf.get_variable('w'), tf.get_variable('b')])
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver=None
# tf.variables_initializer([w2,b2]) # initialize variables w2, b2

saver=tf.train.Saver() # no argument: save all variables
saver.save(sess,logdir+'model.ckpt')
print('w',w.eval())
print('-----------')
print('b',b.eval())
print('-----------')
# print('w',w2.eval())
# print('-----------')
# print('b',b2.eval())

sess.close()

说明:

with tf.variable_scope('conv'):
    w=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer)
    b=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    # NOTE: the Python names w and b are rebound here, shadowing the
    # 'conv' variables created above.
    w=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer)
    b=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer)
print('w',w.eval()) # prints w from scope 'conv2'
print('-----------')
print('b',b.eval())# prints b from scope 'conv2'

如果要打印的是 'conv' 中的w,b

# BUG FIX: tf.variable_op_scope(...).args[0] is a hack on a deprecated API
# that simply returns its first argument. Fetch the 'conv' tensors by their
# graph names instead.
graph = tf.get_default_graph()
print('w', graph.get_tensor_by_name('conv/w:0').eval())
print('-----------')
print('b', graph.get_tensor_by_name('conv/b:0').eval())

打印’conv2’ 中的w,b也可以使用

# BUG FIX: same as above — avoid the deprecated tf.variable_op_scope hack
# and fetch the 'conv2' tensors by their graph names.
graph = tf.get_default_graph()
print('w', graph.get_tensor_by_name('conv2/w:0').eval())
print('-----------')
print('b', graph.get_tensor_by_name('conv2/b:0').eval())
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值