1 模型保存
【Demo】
import os

import tensorflow as tf
import numpy as np

tf.reset_default_graph()

# Placeholders, kept in the graph so they can be fed after the model is loaded.
x1 = tf.placeholder(tf.int32, shape=[1], name="x")
y1 = tf.placeholder(tf.int32, shape=[1], name="y")

# Variables whose values are written to the checkpoint.
v1 = tf.Variable(tf.constant(125, shape=[1]), name="v_1")
v2 = tf.Variable(tf.constant(125, shape=[1]), name="v_2")
global_step = tf.Variable(0, trainable=False, name="global_step")
result = v1 + v2

# Saver object used for both restoring and saving.
saver = tf.train.Saver()

with tf.Session() as sess:
    # Give every variable its initial value before saving or restoring.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # Load a previously saved model when one is present on disk.
    if os.path.isfile("./test_model/plus_model.ckpt.meta"):
        ckpt = tf.train.get_checkpoint_state("./test_model")
        # get_checkpoint_state yields None when the "checkpoint" index file
        # is missing, so only restore when a usable path was found.
        if ckpt and ckpt.model_checkpoint_path:
            model_path = ckpt.model_checkpoint_path
            saver.restore(sess, model_path)

    # Save the model.
    saver.save(sess, "./test_model/plus_model.ckpt")
    # Also export the graph's meta data in text form.
    saver.export_meta_graph("./test_model/plus_model.ckpt.meta.json", as_text=True)
【模型结构】
序号 | 文件 | 描述 |
---|---|---|
1 | checkpoint | 文本文件,记录最新的模型文件列表 |
2 | .data | 包含训练变量的值value |
3 | .index | 包含.data和.meta文件对应关系 |
4 | .meta | 包含网络图结构,如GraphDef,SaverDef |
2 模型载入
2.1 只载入模型参数忽略图结构
【Demo1】
import tensorflow as tf
import numpy as np
tf.reset_default_graph()

# The saved graph structure is not loaded here, so the graph is rebuilt by
# hand. Variable names must match the names used by the saved model.
v1 = tf.Variable([0], name="v_1")
v2 = tf.Variable([0], name="v_2")
result = v1 + v2

with tf.Session() as sess:
    # Saver covering the variables declared above.
    saver = tf.train.Saver()

    # get_checkpoint_state reads the model's "checkpoint" file, which stores
    # the checkpoint information, and exposes the checkpoint path.
    model_path = tf.train.get_checkpoint_state("./test_model").model_checkpoint_path

    # restore() takes the ckpt path; when there is no "checkpoint" file the
    # ckpt path has to be passed to restore() by hand.
    saver.restore(sess, model_path)

    print("Model path: {}".format(model_path))
    restored_sum = sess.run(result)
    print("Result: {}".format(restored_sum))
【Result】
Model path: ./test_model/plus_model.ckpt
Result: [250]
【Analysis】
(1) 只载入模型的参数,而不使用模型图结构时,需要重新定义图结构,并且要保证新建变量与模型变量的name
一致;定义的节点可以少于模型原有节点,但只能使用已定义的节点。例如:即使模型中有v_1
节点,如果在新建图时没有定义,仍旧不能使用v_1
节点,因为该模式下只载入模型参数,未获取模型的图结构,因此定义多少,使用多少;
(2) 新建的图结构需要有可保存的变量,即使用tf.Variable
定义的节点,并且需要初始化(即赋初值),否则save
或restore
会报No variable available
错误;
(3) 模型保存的参数值是固定的,所以新建图节点的初始值不会影响载入模型后的计算值;
【Demo2】
载入神经网络参数。
# Dedicated graph that the restored network parameters are loaded into.
g_params = tf.Graph()

with g_params.as_default():
    # Input placeholders; the first dimension is left open (None).
    x = tf.placeholder(dtype=tf.float32, shape=[None, INPUT_NODE], name="x_input")
    y_ = tf.placeholder(dtype=tf.float32, shape=[None, OUTPUT_NODE], name="y_input")

    # Build the network output with an L2 regularizer.
    # NOTE(review): `inference` is defined outside this snippet - confirm its contract.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
    y = inference(x, regularizer)
def load_model_only_with_params(images_num):
    """Restore saved parameters into ``g_params`` and predict MNIST test images.

    Restores the checkpoint listed under ``./new_models``, feeds the first
    ``images_num`` test images through ``y`` one at a time, prints the raw
    outputs and the true/predicted labels, and saves a confusion-matrix plot
    to ``./images/confusion_matrix.png``.

    Args:
        images_num: number of test images (taken from the start of the test
            set) to run through the network.
    """
    # Function-scope import: label/shape handling is done with NumPy so that
    # no new TensorFlow ops are added to the graph on every loop iteration.
    import numpy as np

    # NOTE(review): `input_data` is presumably the TensorFlow MNIST helper
    # imported elsewhere in the original script - confirm.
    mnist = input_data.read_data_sets("./mnist_data", one_hot=True)
    images_data = [mnist.test.images[i] for i in range(images_num)]
    images_label = [mnist.test.labels[i] for i in range(images_num)]
    predict_labels = []
    train_labels = []

    with tf.Session(graph=g_params) as sess:
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state("./new_models")
        model_path = ckpt.model_checkpoint_path
        saver.restore(sess, model_path)

        for i in range(images_num):
            # One-hot label -> class index (e.g. 7).
            train_labels.append(np.argmax(images_label[i]))

            # Add a leading batch dimension so the image matches placeholder x.
            images = np.expand_dims(images_data[i], 0)
            # print("images shape: {}".format(images.shape))
            pre = sess.run(y, feed_dict={x: images})

            pre_value = pre[0]
            print("Extract value from predicted result: {}".format(pre_value))
            sum_pre = sum(pre_value.tolist())
            print("sum predicted value: {}".format(sum_pre))

            # Index of the largest output is the predicted digit.
            predict_labels.append(np.argmax(pre, 1)[0])

    print("train data labels: {}".format(train_labels))
    print("predicted number: {}".format(predict_labels))

    conf_mx = confusion_matrix(train_labels, predict_labels)
    print("confusion matrixs: \n {}".format(conf_mx))

    # Make sure the output directory exists before saving the figure.
    if not os.path.exists("./images"):
        os.makedirs("./images")
    plt.matshow(conf_mx, cmap=plt.cm.gray)
    plt.title("训练数据和预测数据的混淆矩阵", fontproperties=font)
    plt.savefig("./images/confusion_matrix.png", format="png")
    plt.show()
2.2 载入模型参数和图结构
【Demo】
import tensorflow as tf
tf.reset_default_graph()

with tf.Session() as sess:
    # Load the model's graph structure from the .meta file.
    saver = tf.train.import_meta_graph("test_model/plus_model.ckpt.meta")

    # Load the model parameters from the latest checkpoint.
    model_file = tf.train.latest_checkpoint("test_model/")
    saver.restore(sess, model_file)

    # Tensors are looked up by name on the default graph.
    g = tf.get_default_graph()
    for tensor_name in ("add:0", "v_1:0", "v_2:0"):
        fetched = sess.run(g.get_tensor_by_name(tensor_name))
        print("Result: {}".format(fetched))
【Result】
Result: [250]
Result: [125]
Result: [125]
【Analysis】
(1) 直接载入模型的全部内容,包括模型结构和模型参数,因此不需要重新定义图结构;
(2) 获取张量值时
,需要先获取默认的图结构
,再在图结构中通过张量名称获取张量值;
3 模型结构分析
3.1 载入参数忽略图结构
【Demo】
import tensorflow as tf
import numpy as np
tf.reset_default_graph()

# Build a new graph and define the nodes inside it.
g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.int32, name="x")
    v1 = tf.Variable(tf.constant(0, shape=[1]), name="v_1")
    v2 = tf.Variable(tf.constant(0, shape=[1]), name="v_2")
    global_step = tf.Variable(0, trainable=False, name="global_step")
    result = v1 + v2

with tf.Session(graph=g) as sess:
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state("./test_model")
    model_path = ckpt.model_checkpoint_path
    print("Model path: {}".format(model_path))
    saver.restore(sess, model_path)

    # Trainable variables defined with tf.Variable.
    trainable_variable_in_model = tf.trainable_variables()
    print("Trainable variables in model: {}".format(trainable_variable_in_model))

    # All global variables.
    g_variable = tf.global_variables()
    print("Global variables in model: {}".format(g_variable))

    # Names of the trainable variables.
    variable_in_model = [v.name for v in trainable_variable_in_model]
    print("Trainable variables name in modle: {}".format(variable_in_model))

    # Default graph of the current context.
    g = tf.get_default_graph()

    # Every operation in the graph.
    operations_in_model = g.get_operations()
    print("Operations in model: {}".format(operations_in_model))

    # Look an operation up by name, then its output list and first output tensor.
    name_operation_placeholder = g.get_operation_by_name('x')
    name_operation_outputs_placeholder = name_operation_placeholder.outputs
    name_operation_outputs_value_placeholder = name_operation_outputs_placeholder[0]
    print("operation name: {}".format(name_operation_placeholder))
    print("operation outputs: {}".format(name_operation_outputs_placeholder))
    print("operation outputs value: {}".format(name_operation_outputs_value_placeholder))

    name_operation = g.get_operation_by_name('v_1')
    name_operation_outputs = name_operation.outputs
    name_operation_outputs_value = name_operation_outputs[0]
    print("operation name: {}".format(name_operation))
    print("operation outputs: {}".format(name_operation_outputs))
    print("operation outputs value: {}".format(name_operation_outputs_value))

    # Evaluate tensors fetched by name.
    for tensor_name in ("add:0", "v_1:0", "v_2:0"):
        print("Result: {}".format(sess.run(g.get_tensor_by_name(tensor_name))))
【Result】
Model path: ./test_model/plus_model.ckpt
INFO:tensorflow:Restoring parameters from ./test_model/plus_model.ckpt
Trainable variables in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>]
Global variables in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'global_step:0' shape=() dtype=int32_ref>]
Trainable variables name in modle: ['v_1:0', 'v_2:0']
Operations in model: [<tf.Operation 'x' type=Placeholder>, <tf.Operation 'Const' type=Const>, <tf.Operation 'v_1' type=VariableV2>, <tf.Operation 'v_1/Assign' type=Assign>, <tf.Operation 'v_1/read' type=Identity>, <tf.Operation 'Const_1' type=Const>, <tf.Operation 'v_2' type=VariableV2>, <tf.Operation 'v_2/Assign' type=Assign>, <tf.Operation 'v_2/read' type=Identity>, <tf.Operation 'global_step/initial_value' type=Const>, <tf.Operation 'global_step' type=VariableV2>, <tf.Operation 'global_step/Assign' type=Assign>, <tf.Operation 'global_step/read' type=Identity>, <tf.Operation 'add' type=Add>, <tf.Operation 'save/Const' type=Const>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/Assign_1' type=Assign>, <tf.Operation 'save/Assign_2' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>]
operation name: name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
operation outputs: [<tf.Tensor 'x:0' shape=<unknown> dtype=int32>]
operation outputs value: Tensor("x:0", dtype=int32)
operation name: name: "v_1"
op: "VariableV2"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
operation outputs: [<tf.Tensor 'v_1:0' shape=(1,) dtype=int32_ref>]
operation outputs value: Tensor("v_1:0", shape=(1,), dtype=int32_ref)
Result: [250]
Result: [125]
Result: [125]
【Analysis】
(1) 使用模型时,需要先载入模型;
(2) 通过tf.trainable_variables
获取模型的Variable定义的可用于训练的变量,[<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>]
,新的图中定义了4
个节点(x、v_1、v_2、global_step),其中可用于训练的变量只有2个,即v_1
和v_2
,虽然global_step
是Variable定义的,但是trainable=False
声明该变量不能用于训练;
(3) 通过tf.global_variables
获取全部tf.Variable
定义的变量(节点),结果[<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'global_step:0' shape=() dtype=int32_ref>]
,共有三个变量,v_1
,v_2
和global_step
;
(4) 通过tf.get_default_graph().get_operations()
获取模型的所有操作,即新图定义的所有节点的操作,新图定义的节点有x,v_1, v_2和global_step,结果:[<tf.Operation 'x' type=Placeholder>, <tf.Operation 'Const' type=Const>, <tf.Operation 'v_1' type=VariableV2>, <tf.Operation 'v_1/Assign' type=Assign>, <tf.Operation 'v_1/read' type=Identity>, <tf.Operation 'Const_1' type=Const>, <tf.Operation 'v_2' type=VariableV2>, <tf.Operation 'v_2/Assign' type=Assign>, <tf.Operation 'v_2/read' type=Identity>, <tf.Operation 'global_step/initial_value' type=Const>, <tf.Operation 'global_step' type=VariableV2>, <tf.Operation 'global_step/Assign' type=Assign>, <tf.Operation 'global_step/read' type=Identity>, <tf.Operation 'add' type=Add>, <tf.Operation 'save/Const' type=Const>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/Assign_1' type=Assign>, <tf.Operation 'save/Assign_2' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>]
,其中,变量类型的节点(操作)有Placeholder,VariableV2和Add,分别对应变量x(Placeholder),v_1(VariableV2),v_2(VariableV2),global_step(VariableV2);因为在Variable中嵌入了const,所以有Const节点;而模型中总共定义了四个节点(张量),返回只有三个,少了y(Placeholder),说明该方法只载入模型参数,未导入图结构,新建的图结构有多少,就会返回多少张量操作;
(4) 通过tf.get_default_graph().get_operation_by_name('x')
获取节点x
的图结构;
name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
unknown_rank: true
}
}
}
(5) 通过tf.get_default_graph().get_operation_by_name('x').outputs
获取x节点的张量列表,结果: [<tf.Tensor 'x:0' shape=<unknown> dtype=int32>]
(6) 通过tf.get_default_graph().get_operation_by_name('x').outputs[0]
获取节点x的张量Tensor("x:0", dtype=int32)
,通过该张量可进行赋值;如下解析:
【Demo】
import tensorflow as tf
# Create an int32 placeholder named "x"; no shape is specified.
x = tf.placeholder(tf.int32, name="x")
# Print the placeholder to show its tensor representation.
print("Node x placeholer: {}".format(x))
【Result】
Node x placeholer: Tensor("x:0", dtype=int32)
(7) 通过tf.get_default_graph().get_tensor_by_name("add:0")
获取模型中的张量值,如:Result: [250]
3.2 载入模型的参数和图结构
【Demo】
import tensorflow as tf
tf.reset_default_graph()

with tf.Session() as sess:
    # Import the graph structure, then restore parameters from the latest checkpoint.
    saver = tf.train.import_meta_graph("test_model/plus_model.ckpt.meta")
    model_file = tf.train.latest_checkpoint("test_model/")
    saver.restore(sess, model_file)

    # Trainable variables defined with tf.Variable.
    trainable_vars = tf.trainable_variables()
    print("Tensor in model: {}".format(trainable_vars))

    # All global variables.
    g_variable = tf.global_variables()
    print("Global variables in model: {}".format(g_variable))

    # Names of the trainable variables.
    variable_in_model = [var.name for var in tf.trainable_variables()]
    print("Trainable variables name in modle: {}".format(variable_in_model))

    # Default graph of the current context.
    g = tf.get_default_graph()

    # Every operation in the graph (printing is left disabled).
    operations_in_model = g.get_operations()
    # print("Operations in model: {}".format(operations_in_model))

    # For each named operation show the op itself, its output list and its
    # first output tensor.
    for op_name in ('x', 'v_1'):
        operation = g.get_operation_by_name(op_name)
        for described in (operation, operation.outputs, operation.outputs[0]):
            print("operation name: {}".format(described))

    # Evaluate tensors fetched by name.
    for tensor_name in ("add:0", "v_1:0", "v_2:0"):
        fetched = sess.run(g.get_tensor_by_name(tensor_name))
        print("Result: {}".format(fetched))
【Result】
INFO:tensorflow:Restoring parameters from test_model/plus_model.ckpt
Tensor in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>]
Global variables in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'global_step:0' shape=() dtype=int32_ref>]
Trainable variables name in modle: ['v_1:0', 'v_2:0']
operation name: name: "x"
op: "Placeholder"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
operation name: [<tf.Tensor 'x:0' shape=(1,) dtype=int32>]
operation name: Tensor("x:0", shape=(1,), dtype=int32)
operation name: name: "v_1"
op: "VariableV2"
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
operation name: [<tf.Tensor 'v_1:0' shape=(1,) dtype=int32_ref>]
operation name: Tensor("v_1:0", shape=(1,), dtype=int32_ref)
Result: [250]
Result: [125]
Result: [125]
【Analysis】
(1) 与3.1
节中不同的是,该模式下,会载入模型的全部参数和图结构,因此,使用tf.get_default_graph().get_operations()
会获取所有节点的操作;
(2) 其他与3.1
节相同;
4 模型预测准确度
【Demo】
import numpy as np
import tensorflow as tf

tf.reset_default_graph()

# One row of scores per sample, and one class index per sample.
pre = np.array([[0, 1, 3], [0, 3, 2], [1, 0, -1], [10, 1, 3]])
label = np.array([2, 1, 0, 0])
pre_t = tf.convert_to_tensor(pre, tf.float32)
label_t = tf.convert_to_tensor(label, tf.int32)

# With k=1 a sample counts as correct when its label is the top-scoring class.
correct = tf.nn.in_top_k(pre_t, label_t, 1)
correct_num = tf.cast(correct, tf.float32)
# Mean of the 0/1 correctness flags is the accuracy.
accuracy = tf.reduce_mean(correct_num)

with tf.Session() as sess:
    # Fetched values get their own names so the graph tensors are not shadowed.
    correct_val, correct_num_val, accuracy_val = sess.run(
        [correct, correct_num, accuracy]
    )
    print(correct_val)
    print(correct_num_val)
    print(accuracy_val)
【Result】
[ True True True True]
[1. 1. 1. 1.]
1.0
5 总结
(1) get_checkpoint_state检查模型的checkpoint文件,获取模型节点信息,即ckpt文件。
(2) restore中的路径信息为模型中ckpt文件的路径,若模型文件中没有checkpoint文件,需要手动添加ckpt文件路径到restore中。
(3) 模型载入有两种模式,仅载入模型参数,载入模型参数和模型结构。
(4) 模型准确度即所有预测结果中预测正确的数量所占的比例,也就是预测值与实际值是否相符(相符记1,不符记0)的平均值。
[参考文献]
[1]https://tensorflow.google.cn/guide/saved_model
[2]https://www.cnblogs.com/hellcat/p/6925757.html