保存模型参数（训练时）：
# Create the first conv layer's kernel under an explicit variable name; a
# script that restores this checkpoint must create a variable with the same name.
conv1_weights=tf.get_variable('conv1_weights',[CONV1_SIZE,CONV1_SIZE,NUM_CHANNELS,CONV1_DEEP],initializer=tf.truncated_normal_initializer(stddev=0.1))
MODEL_SAVE_PATH='./tensorflow_model'
MODEL_NAME='model_test.ckpt'
saver=tf.train.Saver()
# NOTE(review): `sess` and `global_step` are expected to come from the
# surrounding training loop, which is not shown in this snippet.
# Saver.save() does not create the parent directory and errors out if it is
# missing, so make sure it exists first.
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
# Passing global_step makes the checkpoint prefix end in "-<step>"
# (e.g. model_test.ckpt-1501); the reading side must use that full name.
saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)
读取模型文件中的参数。需要注意的是，MODEL_NAME 后面需要加上要读取的那一步的步数后缀（即 saver.save() 传入 global_step 时自动追加的 "-步数"，例如 model_test.ckpt-1501）。
conv1_weights.eval() 返回的是 numpy 数组，用 numpy.save() 就可以把参数保存到文件中。
import numpy as np
import tensorflow as tf
import os
# Start from an empty graph so only the variable re-declared below exists.
tf.reset_default_graph()

# Shape hyper-parameters of the first conv layer; they must match the values
# used when the checkpoint was written.
NUM_CHANNELS=1
CONV1_DEEP=16
CONV1_SIZE=3

# The variable name must match the name used during training, because the
# default Saver matches checkpoint entries to graph variables by name.
conv1_weights=tf.get_variable('conv1_weights',[CONV1_SIZE,CONV1_SIZE,NUM_CHANNELS,CONV1_DEEP])
saver = tf.train.Saver()

MODEL_SAVE_PATH='./tensorflow_model'
# Checkpoint prefix including the "-<global_step>" suffix appended by
# saver.save(..., global_step=...); here the checkpoint from step 1501.
MODEL_NAME='model_test.ckpt-1501'

with tf.Session() as sess:
    # restore() loads the saved values into the variable; no initializer is
    # run in this script.
    saver.restore(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME))
    # eval() uses the default session (sess) and returns a numpy array.
    conv1_w = conv1_weights.eval()
    # print ("weights= ", conv1_weights.eval())

# np.save() does not create missing directories, so ensure the target exists.
os.makedirs('./params', exist_ok=True)
np.save('./params/conv1_w.npy',conv1_w)