是否会有这样的困惑:
-
网络相同,不同类别之间的权重如何加载?
-
网络相同,现在需要添加正则化损失训练,如何加载没有添加正则化损失训练的权重。或者添加正则化损失,现在不需要了,如何加载?
1.首先在网络层需要对比不同类别或者不同损失的op的变化情况:
from tensorflow.python import pywrap_tensorflow
checkpoint_path_rg = '/data/git/ocr-platform/weights/recognize/ocr_densenet_tensorflow/metal/metal_metal/densenet_adderrall__44ckpt____gru_rd_block_rot____rg_16.ckpt'
reader_rg = pywrap_tensorflow.NewCheckpointReader(checkpoint_path_rg) #tf.train.NewCheckpointReader
var_to_shape_map_rg = reader_rg.get_variable_to_shape_map()
with open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op_rg.txt","w") as f:
for key in var_to_shape_map_rg:
print(key)
f.write(key+'\n')
checkpoint_path = '/data/git/ocr-platform/weights/recognize/ocr_densenet_tensorflow/metal/metal_metal/densenet_modify035_197ckpt_bd10-37rot___gru_rd_block___reverse_const__188.ckpt'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
with open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op.txt","w") as f0:
for key in var_to_shape_map:
print(key)
f0.write(key+'\n')
保存下来不同类别(或者不同损失)的两个权重的op文件:
2.接着来对比这两个op文件的不同之处:
f1 = open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op.txt","r") #line1
f2 = open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/train_op_rg.txt","r") #line2
txt1 = f1.read()
txt2 = f2.read()
f1.close()
f2.close()
line1 = txt1.split()
line2 = txt2.split()
line1 = [line[:-2] for line in line1]
line2 = [line[:-2] for line in line2]
# print(line1)
# print(line2)
outfile = open("/data/git/ocr-platform/statistic/recognize/ocr_densenet_tensorflow/preprocess/diff.txt", "w")
for i in line1:
if i not in line2:
outfile.write(i+'\n')
outfile.write("Above content in 1. But not in 2."+'\n')
for j in line2:
if j not in line1:
outfile.write(j+'\n')
outfile.write("Above content in 2. But not in 1.")
print("核对结束")
打开diff.txt文件即可看到不同的op了:
3.将不同的两个op:Variable_1、Variable_1/Momentum放入到exclude中:
exclude=['Variable_1','Variable_1/Momentum']
在加载权重的时候过滤掉:
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver = tf.train.Saver()
exclude=[]
if not args.model_path == '':
# saver.restore(sess, args.model_path)
# var = tf.global_variables()
# for val in var :
# print(val.name) #查看所有变量名称
#过滤参与训练但是不能参与训练的变量(由于类别不同或者损失不同)
################ 删除不同类别的op ################
#exclude=['sequence_rnn_module/w']
################ 删除不同损失的op ################
exclude=['Variable_1','Variable_1/Momentum']
variables_to_restore = tf.contrib.slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, args.model_path)
注意!!:保存权重的时候记得重新定义saver=tf.train.Saver(),否则会使用之前删除变量权重的保存对象saver,这样保存下来的权重是不完整的,预测的时候无法加载。
if len(exclude)!=0:
############重新定义##########
saver = tf.train.Saver()
saver.save(sess, ckpt_path+"_%d.ckpt"%(epoch), write_meta_graph=True)
else:
saver.save(sess, ckpt_path+"_%d.ckpt"%(epoch), write_meta_graph=True)