TensorFlow - Filtering Out Unwanted Variables When Loading Weights (Selective Loading)

 Have you ever run into questions like these:

  • The network is the same but the number of classes differs: how do you load the weights?

  • The network is the same, but you now want to train with an added regularization loss: how do you load weights that were trained without it? Or the other way around: regularization loss was used before but is no longer needed, so how do you load those weights?

1. First, compare how the ops in the network change between the two setups (different class counts or different losses):

from tensorflow.python import pywrap_tensorflow

# Read the variable names (ops) stored in the first checkpoint
checkpoint_path_rg = '/data/git/ocr-platform/weights/recognize/ocr_densenet_tensorflow/metal/metal_metal/densenet_adderrall__44ckpt____gru_rd_block_rot____rg_16.ckpt'
reader_rg = pywrap_tensorflow.NewCheckpointReader(checkpoint_path_rg)  # same as tf.train.NewCheckpointReader
var_to_shape_map_rg = reader_rg.get_variable_to_shape_map()
with open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op_rg.txt", "w") as f:
    for key in var_to_shape_map_rg:
        print(key)
        f.write(key + '\n')

# Read the variable names (ops) stored in the second checkpoint
checkpoint_path = '/data/git/ocr-platform/weights/recognize/ocr_densenet_tensorflow/metal/metal_metal/densenet_modify035_197ckpt_bd10-37rot___gru_rd_block___reverse_const__188.ckpt'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)  # same as tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
with open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op.txt", "w") as f0:
    for key in var_to_shape_map:
        print(key)
        f0.write(key + '\n')

 This saves the op (variable name) files for the two checkpoints (different class counts or different losses).
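As a quick sanity check, tf.train.list_variables offers a more compact way to list what a checkpoint contains. Below is a minimal sketch, not part of the original workflow; it assumes a TensorFlow 1.x environment and reuses checkpoint_path_rg from the snippet above:

import tensorflow as tf

# Print every (variable name, shape) pair stored in the checkpoint,
# without having to build the model graph first
for name, shape in tf.train.list_variables(checkpoint_path_rg):
    print(name, shape)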

 

2. Next, compare the two op files and find where they differ:

# Compare the variable names in the two op files and write the differences to diff.txt
f1 = open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op.txt", "r")     # line1
f2 = open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op_rg.txt", "r")  # line2
txt1 = f1.read()
txt2 = f2.read()
f1.close()
f2.close()
line1 = txt1.split()
line2 = txt2.split()
# print(line1)
# print(line2)
outfile = open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/diff.txt", "w")
for i in line1:
    if i not in line2:
        outfile.write(i + '\n')
outfile.write("Above content in 1. But not in 2." + '\n')
for j in line2:
    if j not in line1:
        outfile.write(j + '\n')
outfile.write("Above content in 2. But not in 1.")
outfile.close()
print("Comparison finished")

Open diff.txt to see the ops that differ:
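For reference, the same comparison can be written more compactly with Python sets; a minimal sketch that assumes the two op files from step 1 already exist:

# Set-based diff of the two variable-name lists
with open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op.txt") as f1, \
     open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op_rg.txt") as f2:
    ops1 = set(f1.read().split())
    ops2 = set(f2.read().split())

print("In 1 but not in 2:", sorted(ops1 - ops2))
print("In 2 but not in 1:", sorted(ops2 - ops1))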

3. Put the two differing ops, Variable_1 and Variable_1/Momentum, into the exclude list:

exclude=['Variable_1','Variable_1/Momentum'] 

Filter them out when loading the weights:

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver = tf.train.Saver()
exclude = []
if not args.model_path == '':
    # saver.restore(sess, args.model_path)
    # var = tf.global_variables()
    # for val in var:
    #     print(val.name)  # print all variable names

    # Exclude the variables whose weights cannot be restored from the old
    # checkpoint (because the class count or the loss differs)
    ################ ops that differ because of the class count ################
    # exclude = ['sequence_rnn_module/w']
    ################ ops that differ because of the loss ################
    exclude = ['Variable_1', 'Variable_1/Momentum']

    # Restore everything except the excluded variables
    variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, args.model_path)
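If tf.contrib.slim is not available in your environment, the same filtering can be done by hand against tf.global_variables(). This is only a sketch of the idea, and its name matching may differ slightly from slim's scope-based matching; the names in exclude are the ones found in diff.txt above:

# Keep every global variable except those whose op name matches,
# or lives under, one of the excluded names
exclude = ['Variable_1', 'Variable_1/Momentum']
variables_to_restore = [
    v for v in tf.global_variables()
    if not any(v.op.name == e or v.op.name.startswith(e + '/') for e in exclude)
]
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, args.model_path)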

Note!! When saving the weights, remember to re-create saver = tf.train.Saver(). Otherwise the saver that was built on the filtered variable list will be reused, the saved checkpoint will be incomplete, and it will fail to load at prediction time.

if len(exclude) != 0:
    # Re-create the Saver over all variables before saving
    saver = tf.train.Saver()
    saver.save(sess, ckpt_path + "_%d.ckpt" % (epoch), write_meta_graph=True)
else:
    saver.save(sess, ckpt_path + "_%d.ckpt" % (epoch), write_meta_graph=True)
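To confirm that the re-created Saver really wrote a complete checkpoint, you can read it back and check for the previously excluded ops; a minimal sketch, assuming the checkpoint was just saved to the path built above:

from tensorflow.python import pywrap_tensorflow

# Read back the checkpoint that was just saved and verify that the
# previously excluded variables are present again
saved_ckpt = ckpt_path + "_%d.ckpt" % (epoch)
reader = pywrap_tensorflow.NewCheckpointReader(saved_ckpt)
saved_vars = reader.get_variable_to_shape_map()
for name in ['Variable_1', 'Variable_1/Momentum']:
    print(name, "saved" if name in saved_vars else "MISSING")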
