tensorflow使用记录

以vc版本的tensorpack说明

一.tf.Variable和tf.get_variable的区别

 在 tf.name_scope() 环境下分别使用 tf.get_variable() 和 tf.Variable()创建变量

  1. 创建变量时,name属性值允许重复,检查到相同名字的变量时,由自动别名机制创建不同的变量
with tf.name_scope('a_name_scope'):
    tf.Variable(name=None, initial_value, validate_shape=True, trainable=True, collections=None)
  1. 会先查找已经创建的变量,如果有同名的变量,当变量性质为共享变量时,就返回已存在的变量;如果为非共享变量,会报错
with tf.name_scope('a_name_scope',reuse=True):
    tf.get_variable(name, shape=None, initializer=None, dtype=tf.float32, trainable=True, collections=None)

参考:https://blog.csdn.net/lanchunhui/article/details/61914287

模型调用

  • 每次运行,会有checkpoint、graph、model生成
    1⃣️其中,若文件夹已经有checkpoint,且写有自动掉用上次模型,可以在上次的基础上继续训练,否则重新生成,且不能调用之前的模型,即使已经存在
    2⃣️每次运行会重新生成graph,即使上次的已经存在,因此调用上次模型与文件夹中是否有graph无关

权重变量查看

import numpy as np

import tensorflow as tf
import sys

model = sys.argv[1]
tensor = sys.argv[2]
reader = tf.train.NewCheckpointReader(model)
all_variables = reader.get_variable_to_shape_map()

#reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
#param_dict = reader.get_variable_to_shape_map()

for key, val in all_variables.items():
    try:
        print key, val
        #key是网络参数名,val是维度
    except:
        pass
w0 = reader.get_tensor(tensor)
np.save('con1d_w.npy',w0)
print(type(w0))
print(w0.shape)
print(w0[0])


参考:
【1】TensorFlow查看ckpt中变量的几种方法https://blog.csdn.net/u014061630/article/details/80461044
【2】https://blog.csdn.net/liujianlin01/article/details/80426348
【3】https://blog.csdn.net/hustchenze/article/details/83625223

文件内容

  • chekpoint—记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model
    在这里插入图片描述
  • MyModel.meta文件保存的是图结构,meta文件是pb(protocol buffer)格式文件,包含变量、op、集合等。
  • ckpt文件是二进制文件,保存了所有的weights、biases、gradients等变量。在tensorflow 0.11之前,保存在.ckpt文件中。0.11后,通过两个文件保存,如:
MyModel.data-00000-of-00001
MyModel.index

tf.nn.dynamic_rnn 详解

参考: https://zhuanlan.zhihu.com/p/43041436

output, last_state = tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
nameshape
cellint, lstm or gru的神经元数,与输出size有关
input[batch_size, max_length, embedding_size]
sequence_length[int, int,…]对应输入序列的实际长度,应用于padding的非定长输入
output[batch_size, max_length, cell]
state[batch_size, cell.output_size ] or [2, batch_size, cell.output_size ]

output 和state的关系

在这里插入图片描述
在这里插入图片描述
以上两个图是lstm的结构,对应的last_state有【 c t , h t c_t, h_t ct,ht】,cell_state(应该记住或遗忘的状态), h t h_t ht(实际的输出),因此state是【2, batch_size, cell】
c t c_t ct对应中间的每一个状态【batch_size, max_length, cell_size】
last_state中的 h t h_t ht对应的是output中最后一个输出(每一个输入最后一个不为0的部分)

例如:输入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【2,3,5】

在这里插入图片描述
GRU是LSTM修改的RNN,对应只有一个输出,以及向后层传递的 h t h_t ht,所以state=【batch_size, cell_size】

同理,对于gru,例如:输入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【3,5】

tensorflow模型保存

  • 保存了3个文件
model.ckpt-10000.data-00000-of-00001                         
model.ckpt-10000.index                                    
model.ckpt-10000.meta 
  • 一般调用生成的模型,直接model.ckpt-1000这样的格式即可
  • data中存储的是模型的变量值
  • index 存储的是tensor名称
  • meta 存储的是graph结构,包括 GraphDef, SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值

计算模型中的参数量

  • keras是可以直接输出每层的结构,并且在最后自动计算参数量
  • 普通的tensorflow可以调用训练生成的模型,计算参数量
from tensorflow.python import pywrap_tensorflow
import os
import numpy as np

checkpoint_path = os.path.join("models_pretrained/", "model.ckpt-82798")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
total_parameters = 0
for key in var_to_shape_map:#list the keys of the model
    # print(key)
    # print(reader.get_tensor(key))
    shape = np.shape(reader.get_tensor(key))  #get the shape of the tensor in the model
    shape = list(shape)
    # print(shape)
    # print(len(shape))
    variable_parameters = 1
    for dim in shape:
        # print(dim)
        variable_parameters *= dim
    # print(variable_parameters)
    total_parameters += variable_parameters

print(total_parameters)
  1. 计算模型的浮点运算量
    指导方法
    但是还没有成功跑通,暂留!

  2. 日志输出

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]='1' # 这是默认的显示等级,显示所有信息
os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 warning 和 Error 
os.environ["TF_CPP_MIN_LOG_LEVEL"]='3' # 只显示 Error  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值