Tensorflow基础知识:模型保存与载入深度解析(二)

1 模型保存

【Demo】

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# 占位符,供载入模型时赋值
x1 = tf.placeholder(tf.int32, shape=[1], name="x")
y1 = tf.placeholder(tf.int32, shape=[1], name="y")
# 计算图保存的实际变量
v1 = tf.Variable(tf.constant(125, shape=[1]), name="v_1")
v2 = tf.Variable(tf.constant(125, shape=[1]), name="v_2")
global_step = tf.Variable(0, trainable=False, name="global_step")
result = v1 + v2
# 保存模型类
saver = tf.train.Saver()
with tf.Session() as sess:
	# 初始化变量
    init_op = tf.global_variables_initializer()
    # 运行初始化
    sess.run(init_op)
    '''载入已保存的模型'''
    if os.path.isfile("./test_model/plus_model.ckpt.meta"):
		ckpt = tf.train.get_checkpoint_state("./test_model")
		model_path = ckpt.model_checkpoint_path
		saver.restore(sess, model_path)
    # 保存模型
    saver.save(sess, "./test_model/plus_model.ckpt")
    # 保存计算图的元数据结构
    saver.export_meta_graph("./test_model/plus_model.ckpt.meta.json", as_text=True)

【模型结构】

序号文件描述
1checkpoint文本文件,记录最新的模型文件列表
2.data包含训练变量的值value
3.index包含.data和.meta文件对应关系
4.meta包含网络图结构,如GraphDef,SaverDef

2 模型载入

2.1 只载入模型参数忽略图结构

【Demo1】

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# 定义图结构,由于未载入模型的图结构,所以需要新建图结构
# 图结构要保证变量与源模型变量名成一致
v1 = tf.Variable([0], name="v_1")
v2 = tf.Variable([0], name="v_2")
result = v1 + v2

with tf.Session() as sess:
	# 保存变量
    saver = tf.train.Saver()
    '''
    get_checkpoint_state检查模型更新节点,获取节点信息。
    模型中的checkpoint文件存储节点信息。
    '''
    ckpt = tf.train.get_checkpoint_state("./test_model")
    model_path = ckpt.model_checkpoint_path
    '''
    restore中的路径参数即为ckpt文件.
    若没有checkpoint文件,需要手动添加ckpt到restore中。
    '''
    saver.restore(sess, model_path)
    print("Model path: {}".format(model_path))
    print("Result: {}".format(sess.run(result)))

【Result】

Model path: ./test_model/plus_model.ckpt
Result: [250]

【Analysis】
(1) 只载入模型的参数,而不使用模型图结构,需要重新定义图结构,但是要保证新建变量与模型变量的name一致,定义的节点可以少于模型原有节点,并且只能使用定义的节点,情况:即使模型中有v_1节点,但是,在新建图是没有定义,仍就不能使用v_1节点,因为该模式下,只载入模型参数,未获取模型的图结构,因此定义多少,使用多少;
(2) 新建的图结构,需要有可保存的变量,即使用tf.Variable定义的节点,并且需要初始化,即赋初值,否则saverestoreNo variable available报错;
(3) 模型保存的参数值是固定的,所以,新建图节点初始值值不会影响载入模型的计算值;
【Demo2】
载入神经网络参数。

g_params = tf.Graph()
with g_params.as_default():
	x = tf.placeholder(tf.float32, [None, INPUT_NODE], name="x_input")
	y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y_input')
	regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
	y = inference(x, regularizer)

def load_model_only_with_params(images_num):
	mnist = input_data.read_data_sets("./mnist_data", one_hot=True)
	images_data = [mnist.test.images[i] for i in range(images_num)]
	images_label = [mnist.test.labels[i] for i in range(images_num)]
	predict_labels = []
	train_labels = []
	with tf.Session(graph=g_params) as sess:
		saver = tf.train.Saver()
		ckpt = tf.train.get_checkpoint_state("./new_models")
		model_path = ckpt.model_checkpoint_path
		saver.restore(sess, model_path)
		for i in range(images_num):
			train_label = tf.expand_dims(images_label[i], [0])
			train_label = tf.argmax(train_label, 1)
			train_label = sess.run(train_label)
			'''Extract data from list such as [7].'''
			train_labels.append(train_label[0])
			images = tf.expand_dims(images_data[i], [0])
			images = sess.run(images)
#             print("images shape: {}".format(images.shape))
			pre = sess.run(y, feed_dict={x: images})
			pre_value = pre[0]
			print("Extract value from predicted result: {}".format(pre_value))
			sum_pre = sum(pre_value.tolist())
			print("sum predicted value: {}".format(sum_pre))
			'''Get value coresponding number.'''
			pre_num = tf.argmax(pre, 1)
			pre_num = sess.run(pre_num)
			predict_labels.append(pre_num[0])
		print("train data labels: {}".format(train_labels))
		print("predicted number: {}".format(predict_labels))
		conf_mx = confusion_matrix(train_labels, predict_labels)
		print("confusion matrixs: \n {}".format(conf_mx))
		if not os.path.exists("./images"):
			os.makedirs("./images")
		plt.matshow(conf_mx, cmap=plt.cm.gray)
		plt.title("训练数据和预测数据的混淆矩阵", fontproperties=font)
		plt.savefig("./images/confusion_matrix.png", format="png")
		plt.show()

2.2 载入模型参数和图结构

【Demo】

import tensorflow as tf
tf.reset_default_graph()

with tf.Session() as sess:
	# 载入模型的图结构
    saver = tf.train.import_meta_graph("test_model/plus_model.ckpt.meta")
    # 载入模型参数
    model_file = tf.train.latest_checkpoint("test_model/")
    saver.restore(sess, model_file)
    # 获取默认图
    g = tf.get_default_graph()
    print("Result: {}".format(sess.run(g.get_tensor_by_name("add:0"))))
    print("Result: {}".format(sess.run(g.get_tensor_by_name("v_1:0"))))
    print("Result: {}".format(sess.run(g.get_tensor_by_name("v_2:0"))))

【Result】

Result: [250]
Result: [125]
Result: [125]

【Analysis】
(1) 直接模型的全部内容,包括模型结构和模型参数,因此不需要重新定义图结构;
(2) 获取张量值,需要先获取默认的图结构,在图结构中,通过张量名称获张量值;

3 模型结构分析

3.1 载入参数忽略图结构

【Demo】

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# 新建图结构
g = tf.Graph()
# 获取默认图结构
with g.as_default():
    x = tf.placeholder(tf.int32, name="x")
    v1 = tf.Variable(tf.constant(0, shape=[1]), name="v_1")
    v2 = tf.Variable(tf.constant(0, shape=[1]), name="v_2")
    global_step = tf.Variable(0, trainable=False, name="global_step")
    result = v1 + v2
with tf.Session(graph=g) as sess:
    saver = tf.train.Saver()
    ckpt = tf.train.get_checkpoint_state("./test_model")
    model_path = ckpt.model_checkpoint_path
    print("Model path: {}".format(model_path))
    saver.restore(sess, model_path)
    # 获取模型中tf.Variable定义的变量
    trainable_variable_in_model = tf.trainable_variables()
    print("Trainable variables in model: {}".format(variable_in_model))
    # 获取全局变量
    g_variable = tf.global_variables()
    print("Global variables in model: {}".format(g_variable))
    
    # 获取tf.Vaiable定义变量的变量名称
    variable_in_model = [v.name for v in tf.trainable_variables()]
    print("Trainable variables name in modle: {}".format(variable_in_model))
    # 上下文管理器:获取默认图
    g = tf.get_default_graph()
    # 获取模型中的:操作
    operations_in_model = g.get_operations()
	print("Operations in model: {}".format(operations_in_model))
    # 通过操作名称获取操作
    name_operation_placeholder = g.get_operation_by_name('x')
    name_operation_outputs_placeholder = g.get_operation_by_name('x').outputs
    name_operation_outputs_value_placeholder = g.get_operation_by_name('x').outputs[0]
    print("operation name: {}".format(name_operation_placeholder))
    print("operation outputs: {}".format(name_operation_outputs_placeholder))
    print("operation outputs value: {}".format(name_operation_outputs_value_placeholder))
    
    name_operation = g.get_operation_by_name('v_1')
    name_operation_outputs = g.get_operation_by_name('v_1').outputs
    name_operation_outputs_value = g.get_operation_by_name('v_1').outputs[0]
    print("operation name: {}".format(name_operation))
    print("operation outputs: {}".format(name_operation_outputs))
    print("operation outputs value: {}".format(name_operation_outputs_value))
    
    g = tf.get_default_graph()
    print("Result: {}".format(sess.run(g.get_tensor_by_name("add:0"))))
    print("Result: {}".format(sess.run(g.get_tensor_by_name("v_1:0"))))
    print("Result: {}".format(sess.run(g.get_tensor_by_name("v_2:0"))))

【Result】

Model path: ./test_model/plus_model.ckpt
INFO:tensorflow:Restoring parameters from ./test_model/plus_model.ckpt
Trainable variables in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>]
Global variables in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'global_step:0' shape=() dtype=int32_ref>]
Trainable variables name in modle: ['v_1:0', 'v_2:0']
Operations in model: [<tf.Operation 'x' type=Placeholder>, <tf.Operation 'Const' type=Const>, <tf.Operation 'v_1' type=VariableV2>, <tf.Operation 'v_1/Assign' type=Assign>, <tf.Operation 'v_1/read' type=Identity>, <tf.Operation 'Const_1' type=Const>, <tf.Operation 'v_2' type=VariableV2>, <tf.Operation 'v_2/Assign' type=Assign>, <tf.Operation 'v_2/read' type=Identity>, <tf.Operation 'global_step/initial_value' type=Const>, <tf.Operation 'global_step' type=VariableV2>, <tf.Operation 'global_step/Assign' type=Assign>, <tf.Operation 'global_step/read' type=Identity>, <tf.Operation 'add' type=Add>, <tf.Operation 'save/Const' type=Const>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/Assign_1' type=Assign>, <tf.Operation 'save/Assign_2' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>]
operation name: name: "x"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "shape"
  value {
    shape {
      unknown_rank: true
    }
  }
}

operation outputs: [<tf.Tensor 'x:0' shape=<unknown> dtype=int32>]
operation outputs value: Tensor("x:0", dtype=int32)
operation name: name: "v_1"
op: "VariableV2"
attr {
  key: "container"
  value {
    s: ""
  }
}
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: 1
      }
    }
  }
}
attr {
  key: "shared_name"
  value {
    s: ""
  }
}

operation outputs: [<tf.Tensor 'v_1:0' shape=(1,) dtype=int32_ref>]
operation outputs value: Tensor("v_1:0", shape=(1,), dtype=int32_ref)
Result: [250]
Result: [125]
Result: [125]

【Analysis】
(1) 使用模型时,需要先载入模型;
(2) 通过tf.trainable_variables获取模型的Variable定义的可用于训练的变量,[<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>],新的图中定义了4个变量,可用于训练的只有2个,即v_1v_2,虽然global_step是Variable定义的,但是trainable=False声明该变量不能用于训练;
(3) 通过tf.global_variables获取全部tf.Variable定义的变量(节点),结果[<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'global_step:0' shape=() dtype=int32_ref>],共有三个变量,v_1,v_2global_step;
(4) 通过tf.get_operations()获取模型的所有操作,即新图定义的所有节点的操作,新图定义的节点有x,v_1, v_2和global_step,结果:[<tf.Operation 'x' type=Placeholder>, <tf.Operation 'Const' type=Const>, <tf.Operation 'v_1' type=VariableV2>, <tf.Operation 'v_1/Assign' type=Assign>, <tf.Operation 'v_1/read' type=Identity>, <tf.Operation 'Const_1' type=Const>, <tf.Operation 'v_2' type=VariableV2>, <tf.Operation 'v_2/Assign' type=Assign>, <tf.Operation 'v_2/read' type=Identity>, <tf.Operation 'global_step/initial_value' type=Const>, <tf.Operation 'global_step' type=VariableV2>, <tf.Operation 'global_step/Assign' type=Assign>, <tf.Operation 'global_step/read' type=Identity>, <tf.Operation 'add' type=Add>, <tf.Operation 'save/Const' type=Const>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/Assign_1' type=Assign>, <tf.Operation 'save/Assign_2' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>],其中,变量类型的节点(操作)有Placeholder,VariableV2,和Add,分别对应变量x(Placeholder),v_1(VariableV2),v_2(VariableV2),global_step(VariableV2)因为在Variable中嵌入了const所以,有Contst节点;而模型中总共定义了有四个节点(张量),返回只有三个,少了y(Placeholder),说明了该方法只载入模型参数,未导入图结构,新建的图结构有多少,就会返回多少张量操作;
(4) 通过tf.get_default_graph.get_operation_by_name('x')获取节点x的图结构;

name: "x"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "shape"
  value {
    shape {
      unknown_rank: true
    }
  }
}

(5) 通过tf.get_default_graph.get_operation_by_name('x').outputs获取x节点的张量列表,结果: [<tf.Tensor 'x:0' shape=<unknown> dtype=int32>]
(6) 通过tf.get_default_graph.get_operation_by_name('x').outputs[0]获取节点x的张量Tensor("x:0", dtype=int32),通过该张量可进行赋值;如下解析:
【Demo】

import tensorflow as tf
x = tf.placeholder(tf.int32, name="x")
print("Node x placeholer: {}".format(x))

【Result】

Node x placeholer: Tensor("x:0", dtype=int32)

(7) 通过tf.get_default_graph.get_tensor_by_name("add:0")获取模型中的张量值,如:Result: [250]

3.2 载入模型的参数和图结构

【Demo】

import tensorflow as tf
tf.reset_default_graph()

with tf.Session() as sess:
    saver = tf.train.import_meta_graph("test_model/plus_model.ckpt.meta")
    model_file = tf.train.latest_checkpoint("test_model/")
    saver.restore(sess, model_file)
    # 获取模型中tf.Variable定义的变量
    variable_in_model = tf.trainable_variables()
    print("Tensor in model: {}".format(variable_in_model))
    # 获取全局变量
    g_variable = tf.global_variables()
    print("Global variables in model: {}".format(g_variable))
    
    # 获取tf.Vaiable定义变量的变量名称
    variable_in_model = [v.name for v in tf.trainable_variables()]
    print("Trainable variables name in modle: {}".format(variable_in_model))
    # 上下文管理器:获取默认图
    g = tf.get_default_graph()
    # 获取模型中的:操作
    operations_in_model = g.get_operations()
#     print("Operations in model: {}".format(operations_in_model))
    # 通过操作名称获取操作
    name_operation_placeholder = g.get_operation_by_name('x')
    name_operation_outputs_placeholder = g.get_operation_by_name('x').outputs
    name_operation_outputs_value_placeholder = g.get_operation_by_name('x').outputs[0]
    print("operation name: {}".format(name_operation_placeholder))
    print("operation name: {}".format(name_operation_outputs_placeholder))
    print("operation name: {}".format(name_operation_outputs_value_placeholder))

    
    name_operation = g.get_operation_by_name('v_1')
    name_operation_outputs = g.get_operation_by_name('v_1').outputs
    name_operation_outputs_value = g.get_operation_by_name('v_1').outputs[0]
    print("operation name: {}".format(name_operation))
    print("operation name: {}".format(name_operation_outputs))
    print("operation name: {}".format(name_operation_outputs_value))
    
    g = tf.get_default_graph()
    print("Result: {}".format(sess.run(g.get_tensor_by_name("add:0"))))
    print("Result: {}".format(sess.run(g.get_tensor_by_name("v_1:0"))))
    print("Result: {}".format(sess.run(g.get_tensor_by_name("v_2:0"))))

【Result】

INFO:tensorflow:Restoring parameters from test_model/plus_model.ckpt
Tensor in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>]
Global variables in model: [<tf.Variable 'v_1:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'v_2:0' shape=(1,) dtype=int32_ref>, <tf.Variable 'global_step:0' shape=() dtype=int32_ref>]
Trainable variables name in modle: ['v_1:0', 'v_2:0']
operation name: name: "x"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: 1
      }
    }
  }
}

operation name: [<tf.Tensor 'x:0' shape=(1,) dtype=int32>]
operation name: Tensor("x:0", shape=(1,), dtype=int32)
operation name: name: "v_1"
op: "VariableV2"
attr {
  key: "container"
  value {
    s: ""
  }
}
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: 1
      }
    }
  }
}
attr {
  key: "shared_name"
  value {
    s: ""
  }
}

operation name: [<tf.Tensor 'v_1:0' shape=(1,) dtype=int32_ref>]
operation name: Tensor("v_1:0", shape=(1,), dtype=int32_ref)
Result: [250]
Result: [125]
Result: [125]

【Analysis】
(1) 与3.1节中不同的是,该模式下,会载入模型的全部参数和图结构,因此,使用tf.get_operations()会获取所有节点的操作;
(2) 其他与3.1节相同;

4 模型预测准确度

【Demo】

import tensorflow as tf
tf.reset_default_graph()
pre = np.array([[0, 1, 3], [0, 3, 2], [1, 0, -1], [10, 1, 3]])
label = np.array([2, 1, 0, 0])
pre_t = tf.convert_to_tensor(pre, tf.float32)
label_t = tf.convert_to_tensor(label, tf.int32)
correct = tf.nn.in_top_k(pre_t, label_t, 1)
correct_num = tf.cast(correct, tf.float32)
accuracy = tf.reduce_mean(correct_num)
with tf.Session() as sess:
    correct, correct_num, accuracy = sess.run([correct, correct_num, accuracy])
    print(correct)
    print(correct_num)
    print(accuracy)

【Result】

[ True  True  True  True]
[1. 1. 1. 1.]
1.0

5 总结

(1) get_chieckpoint_state检查模型的checkpoint文件,获取模型节点信息,即ckpt文件。
(2) restore中的路径信息为模型中ckpt文件的路径,若模型文件中没有checkpoint文件,需要手动添加ckpt文件路径到restore中。
(3) 模型载入有两种模式,仅载入模型参数,载入模型参数和模型结构。
(4) 模型准确度即所有预测结果中预测正确的数量比例,预测值与实际值符合的平均值。


[参考文献]
[1]https://tensorflow.google.cn/guide/saved_model
[2]https://www.cnblogs.com/hellcat/p/6925757.html


  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值