import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
def create_model():
    """Build a tiny single-convolution graph under the ``wwj`` variable scope.

    Creates two variables with ``tf.get_variable``:

    * ``wwj/input_data`` -- shape (3, 24, 24, 3), float32. It is used directly
      as the convolution input (a variable, not a placeholder).
    * ``wwj/w_1`` -- shape (3, 3, 3, 32), float32 convolution kernel.

    The scope is opened with ``reuse=tf.AUTO_REUSE`` so this function can be
    called more than once in the same graph (e.g. ``save_model()`` followed by
    ``get_model()``). Without it, the second ``tf.get_variable`` call raises
    ``ValueError`` because ``wwj/input_data`` already exists.

    Returns:
        The ``tf.nn.conv2d`` output tensor (stride 1, "SAME" padding), whose
        shape is (3, 24, 24, 32).
    """
    with tf.variable_scope("wwj", reuse=tf.AUTO_REUSE):
        input_data = tf.get_variable(
            name="input_data", shape=(3, 24, 24, 3), dtype=tf.float32)
        w_1 = tf.get_variable(
            name="w_1", shape=(3, 3, 3, 32), dtype=tf.float32)
        con_1 = tf.nn.conv2d(
            input=input_data, filter=w_1, strides=(1, 1, 1, 1),
            padding="SAME")
    return con_1
def save_model(checkpoint_path="./checkpoints/text.ckpt"):
    """Build the model, initialize its variables and write a checkpoint.

    Args:
        checkpoint_path: Checkpoint prefix passed to ``tf.train.Saver.save``.
            Defaults to the path that used to be hard-coded here.

    Returns:
        The checkpoint path prefix returned by ``Saver.save``.
    """
    import os

    # Build in a private graph so this does not depend on (or pollute) the
    # default graph. Calling create_model() twice in one graph would otherwise
    # collide on the ``wwj/...`` variable names.
    with tf.Graph().as_default():
        create_model()
        saver = tf.train.Saver()
        # Saver.save fails when the parent directory does not exist.
        parent_dir = os.path.dirname(checkpoint_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # The context manager guarantees the session is closed; the previous
        # InteractiveSession was never closed.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            return saver.save(sess, checkpoint_path)
def get_model(checkpoint_path="./checkpoints/text.ckpt"):
    """Rebuild the model, restore it from a checkpoint and print variable shapes.

    For each trainable variable, prints its name and the shape of its
    restored value.

    Args:
        checkpoint_path: Checkpoint prefix to restore (as written by
            ``Saver.save``). Defaults to the path that used to be hard-coded.

    Raises:
        Propagates errors from ``tf.train.Saver.restore``, e.g. when the
        checkpoint does not exist or does not match the rebuilt graph.
    """
    # Private graph: avoids variable-name collisions with anything already
    # built in the default graph (e.g. by an earlier save_model() call).
    with tf.Graph().as_default():
        create_model()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            # restore() assigns every variable the Saver tracks (all global
            # variables here). Running global_variables_initializer() first
            # is therefore unnecessary.
            saver.restore(sess, checkpoint_path)
            # tf.trainable_variables() returns the variables to be trained;
            # tf.global_variables() (formerly tf.all_variables()) returns all
            # of them.
            for v in tf.trainable_variables():
                print(v.name, sess.run(v).shape)
# Read a checkpoint directly, without rebuilding the graph.
# NOTE(review): this reads ssd_300_vgg.ckpt, not the text.ckpt written by
# save_model() -- confirm which checkpoint is intended.
path = "./checkpoints/ssd_300_vgg.ckpt"
# tf.train.NewCheckpointReader is the public export of the private
# tensorflow.python.pywrap_tensorflow.NewCheckpointReader.
reader = tf.train.NewCheckpointReader(path)
var_to_shape_map = reader.get_variable_to_shape_map()
# The shape map already holds every variable's shape, so there is no need to
# load each full tensor with reader.get_tensor(). tuple() keeps the printed
# form identical to the numpy ``.shape`` printed before. Sorting makes the
# output order stable.
for key in sorted(var_to_shape_map):
    print(key, tuple(var_to_shape_map[key]))
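# tensorflow存取模型 (TensorFlow: saving and loading models)
# 最新推荐文章于 2023-06-24 14:19:19 发布 (page footer: latest recommended article published 2023-06-24 14:19:19)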