tensorflow saver_TENSORFLOW ??(2)

v2-75f06c4f19fea5790b7c33b1f5596277_1440w.jpg?source=172ae18b

模型生成与使用(ckpt)

保存

# 首先定义saver类

注意点:

  • 创建saver时,可以指定需要存储的tensor,如果没有指定,则全部保存。
  • 创建saver时,可以指定保存的模型个数,利用max_to_keep=4,则最终会保存4个模型(下图中我保存了共4个模型)。
  • saver.save()函数里面可以设定global_step,说明是哪一步保存的模型。
  • 程序结束后,会生成四个文件:存储网络结构.meta、存储训练好的参数.data和.index、记录最新的模型checkpoint。

v2-ebf37d98e8e98921e4899bc196880a35_b.jpg

加载model

def 

注意点:

  • 首先import_meta_graph,这里填的名字meta文件的名字。然后restore时,是检查checkpoint,所以只填到checkpoint所在的路径下即可,不需要填checkpoint,不然会报错“ValueError: Can’t load save_path when it is None.”。
  • 后面根据具体例子,介绍如何利用加载后的模型得到训练的结果,并进行预测。

Tensorflow模型持久化

提供简单的API来还原和保存一个神经网络模型。API ----> tf.train.Saver类

import 

model.ckpt.meta 保存tensorflow计算图的结构 model.ckpt保存每个变量的取值 checkpoint保存一个目录下所有模型文件的列表

import 

比如在加载模型的代码中使用saver = tf.train.Saver([V1])命令来构建tf.train.Saver类,只加载变量V1会加载进来

ERROR 

重新定义神经网络中的变量

#使用一个字典来重新命名变量就可以加载原来的模型了。这个字典指定原来名称V1 的变量现在加载到变量v1中(other-v1),

训练好的.pb模型文件导出存入 以及 使用

import 

通过以下程序可以直接计算定义的加法运算的结果。当只需要得到计算图的中的某个节点的取值的时候

import 

Tensorflow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图的节点所需要的元数据

import 

meta_info_def属性

MetaInfoDef中的tensorflow_version和tensorflow_git_version属性记录了当前计算图中的他tensorflow的版本

graph_def属性

该属性运算的连接结构,由GraphDef Protocal Buffer定义的,GraphDef主要包含一个NodeDef类型的列表

saver_def属性

记录了持久化模型时需要用的一些参数

filename_tensor_name属性给出了保存文件名的张量名称

restore_op_name属性指定运算的名称

max_to_keep属性和keep_checkpoint_every_n_hours属性设定了tf.train.Saver类清理之前保存的模型的策略

collection_def属性

是一个从集合名称到集合内容的映射,其中集合名称为字符串

NodeList用于维护计算图上节点的集合

BytesList可以维护字符串或者系列化之后的Procotol Buffer的集合

model.ckpt.index和model.ckpt.data--of-文件保存所有变量的取值

如何使用tf.train.NewCheckpointReader类

import tensorflow as tf
​
#tf.train.NewCheckpointReader可以读取checkpoint文件中保存所有的变量
#注意后面的.data .index可以省略
reader = tf.train.NewCheckpointReader('/Path/to/model/model.ckpt')
​
#获取所有变量列表。这个是一个变量名到变量纬度的字典
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
    #variable_name为变量名称,global_variables[variable_name]为变量的维度
    print(variable_name, global_variable[variable_name])
​
#获取名称为V1的变量取值
print("Value for variable v1 is ",reader.get_tensor("v1"))

mnist_inference.py

定义前向传播的过程以及神经网络中的参数

# -*- coding:utf-8 -*-

mnist_train.py

定义神经网络训练过程

# -*- coding:utf-8 -*-

mnist_inference.py

定义测试的过程

# -*- coding:utf-8 -*-

图形识别与卷积神经网络

图形识别问题简介及经典数据集

全连接神经网络 卷积神经网络 循环神经网络

一个卷积层的前向传播过程

filter_weight 

池化层的过滤器除了在长和宽的维度上移动,还需要在深度的维度上移动。最大池化层的前向传播

#tf.nn.max_pool实现最大池化层的前向传播过程,它的参数和tf.nn.conv2d函数类似

tf.nn.max_pool函数和tf.nn.arg_pool函数用法一致

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值