![v2-75f06c4f19fea5790b7c33b1f5596277_1440w.jpg?source=172ae18b](http://img-03.proxy.5ce.com/view/image?&type=2&guid=a7b5ab50-e02e-eb11-8da9-e4434bdf6706&url=https://pic1.zhimg.com/v2-75f06c4f19fea5790b7c33b1f5596277_1440w.jpg?source=172ae18b)
模型生成与使用(ckpt)
保存
# 首先定义saver类
注意点:
- 创建saver时,可以指定需要存储的tensor,如果没有指定,则全部保存。
- 创建saver时,可以指定保存的模型个数,利用max_to_keep=4,则最终会保存4个模型(下图中我保存了共4个模型)。
- saver.save()函数里面可以设定global_step,说明是哪一步保存的模型。
- 程序结束后,会生成四个文件:存储网络结构.meta、存储训练好的参数.data和.index、记录最新的模型checkpoint。
![v2-ebf37d98e8e98921e4899bc196880a35_b.jpg](http://img-03.proxy.5ce.com/view/image?&type=2&guid=a7b5ab50-e02e-eb11-8da9-e4434bdf6706&url=https://pic2.zhimg.com/v2-ebf37d98e8e98921e4899bc196880a35_b.jpg)
加载model
def
注意点:
- 首先import_meta_graph,这里填的名字meta文件的名字。然后restore时,是检查checkpoint,所以只填到checkpoint所在的路径下即可,不需要填checkpoint,不然会报错“ValueError: Can’t load save_path when it is None.”。
- 后面根据具体例子,介绍如何利用加载后的模型得到训练的结果,并进行预测。
Tensorflow模型持久化
提供简单的API来还原和保存一个神经网络模型。API ----> tf.train.Saver类
import
model.ckpt.meta 保存tensorflow计算图的结构 model.ckpt保存每个变量的取值 checkpoint保存一个目录下所有模型文件的列表
import
比如在加载模型的代码中使用saver = tf.train.Saver([V1])命令来构建tf.train.Saver类,只加载变量V1会加载进来
ERROR
重新定义神经网络中的变量
#使用一个字典来重新命名变量就可以加载原来的模型了。这个字典指定原来名称V1 的变量现在加载到变量v1中(other-v1),
训练好的.pb模型文件导出存入 以及 使用
import
通过以下程序可以直接计算定义的加法运算的结果。当只需要得到计算图的中的某个节点的取值的时候
import
Tensorflow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图的节点所需要的元数据
import
meta_info_def属性
MetaInfoDef中的tensorflow_version和tensorflow_git_version属性记录了当前计算图中的他tensorflow的版本
graph_def属性
该属性运算的连接结构,由GraphDef Protocal Buffer定义的,GraphDef主要包含一个NodeDef类型的列表
saver_def属性
记录了持久化模型时需要用的一些参数
filename_tensor_name属性给出了保存文件名的张量名称
restore_op_name属性指定运算的名称
max_to_keep属性和keep_checkpoint_every_n_hours属性设定了tf.train.Saver类清理之前保存的模型的策略
collection_def属性
是一个从集合名称到集合内容的映射,其中集合名称为字符串
NodeList用于维护计算图上节点的集合
BytesList可以维护字符串或者系列化之后的Procotol Buffer的集合
model.ckpt.index和model.ckpt.data--of-文件保存所有变量的取值
如何使用tf.train.NewCheckpointReader类
import tensorflow as tf
#tf.train.NewCheckpointReader可以读取checkpoint文件中保存所有的变量
#注意后面的.data .index可以省略
reader = tf.train.NewCheckpointReader('/Path/to/model/model.ckpt')
#获取所有变量列表。这个是一个变量名到变量纬度的字典
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
#variable_name为变量名称,global_variables[variable_name]为变量的维度
print(variable_name, global_variable[variable_name])
#获取名称为V1的变量取值
print("Value for variable v1 is ",reader.get_tensor("v1"))
mnist_inference.py
定义前向传播的过程以及神经网络中的参数
# -*- coding:utf-8 -*-
mnist_train.py
定义神经网络训练过程
# -*- coding:utf-8 -*-
mnist_inference.py
定义测试的过程
# -*- coding:utf-8 -*-
图形识别与卷积神经网络
图形识别问题简介及经典数据集
全连接神经网络 卷积神经网络 循环神经网络
一个卷积层的前向传播过程
filter_weight
池化层的过滤器除了在长和宽的维度上移动,还需要在深度的维度上移动。最大池化层的前向传播
#tf.nn.max_pool实现最大池化层的前向传播过程,它的参数和tf.nn.conv2d函数类似
tf.nn.max_pool函数和tf.nn.arg_pool函数用法一致