作者:qxl 邮箱: 1183129553@qq.com
系列文章链接
一、tensorflow安装方式及问题汇总
二、TensorFlow基础概念
三、神经网络结构设计
四、mnist数字识别问题
五、图像识别与卷积神经网络
六、U-NET网络
七、TensorFlow常用指令记录
文章目录
前言
最近一直在学习tensorflow,一些常用的指令,如果只是看看,经常会遗忘。 ***常用指令
tf.argmax
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
dimension=0 按列找
dimension=1 按行找
tf.argmax()返回最大数值的下标
通常和tf.equal()一起使用,计算模型准确度
tf.cast
类型转换
tf.equal
import tensorflow as tf
a = [[1,2,3],[4,5,6]]
b = [[1,0,3],[1,5,1]]
with tf.Session() as sess:
print(sess.run(tf.equal(a,b)))
输出:
[[ True False True]
[False True False]]
变量管理-tf.name_scope() 和 tf.variable_scope()
tensorflow提供了通过变量名称来创建或者获取变量的机制。通过此机制,在不同函数中可直接通过变量名字来使用变量,而不需要以参数的形式传递。简单来说,就是为了变量共享(传递)。通过变量名称获取变量的机制主要通过tf.get_variable和tf.variable_scope函数来实现的。
下面介绍一下相关概念:
- tf.get_variable方法负责创建和获取指定名称的变量
- tf.Variable在创建变量时,一律创建新的变量,如果这个变量已存在,则后缀会增加0、1、2等数字编号予以区别。对于,tf.Variable()而言,两种域(name_scope 和 variable_scope)没有差别,输出变量名都有前缀。
- tf.variable_scope负责管理命名空间。tf.variable_scope传入的第一个参数为命名空间的名字。需要注意的是reuse这个参数。默认值为False。当参数reuse=False时,在第二次调用get_variable函数的时候,会抛出异常,显示变量已经存在。reuse=True表示共享该作用域内的参数,即可以多次调用。tf.get_variable()方式创建的变量,只有variable_scope名称会加到变量名称前面,而name_scope不会作为前缀
- name_scope暂时不是很了解,后续补充
用tf.get_variable() 而不用**tf.Variable()**原因:为前者拥有一个变量检查机制,会检测已经存在的变量是否设置为共享变量,如果已经存在的变量没有设置为共享变量,TensorFlow 运行到第二个拥有相同名字的变量的时候,就会报错。
代码示例:
import tensorflow as tf
with tf.variable_scope('V1_domain', reuse=tf.AUTO_REUSE):
a1 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
a2 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a2')
with tf.variable_scope('V2_domain', reuse=tf.AUTO_REUSE):
a3 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
a4 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a2')
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print(a1.name)
print(a2.name)
print(a3.name)
print(a4.name)
with tf.name_scope('V1_name_domain'):
a11 = tf.get_variable(name='a11', shape=[1], initializer=tf.constant_initializer(1))
a12 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a12')
with tf.name_scope('V2_name_domain'):
#a13 = tf.get_variable(name='a11', shape=[1], initializer=tf.constant_initializer(1))
a14 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a12')
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
print(a11.name)
print(a12.name)
#print(a13.name)
print(a14.name)
输出结果:
V1_domain/a1:0
V1_domain/a2:0
V2_domain/a1:0
V2_domain/a2:0
a11:0
V1_name_domain/a12:0
V2_name_domain/a12:0
tensorflow checkpoint
介绍tensorflow2.x和tensorflow1.x中如何保存checkpoint。
1. 什么是checkpoint?
检查点checkpoint(二进制文件)中存储着模型model所使用的所有tf.Variable对象,不包含关于模型的计算信息。只在我们恢复模型的时候checkpoint才有用,因为不知道模型的结构,只知道变量信息是没有意义的。
保存
tf2.x版本
model.save_weights("path_to_my_tf_checkpoints")
tf1.x版本
saver = tf.train.Saver()
saver.saver(sess, model_path+model_name)
一般格式:外层创建saver对象,session里边训练之后保存
with tf.Session() as sess:
...training...
saver.save(sess, modelpath+modelname)
saver类介绍
先看一下函数原型,
tf.train.Saver(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
常用的参数介绍:
- var_list: 保存的变量Variable列表,也可以是一个字典映射,表示我们需要保存哪一些变量,默认情况下不用制定,表示保存所有的变量。
- max_to_keep:保存最近的几份检查点,默认是5,及保存最后5分检查点。
- keep_checkpoint_every_n_hours=10000.0:隔多少个小时保存一次检查点,默认是10000小时
saver类的属性和常用方法
Saver类的属性和常用方法:
Saver类的属性
- last_checkpoints
Saver类的方法:
- as_saver_def
- build
- export_meta_graph
- from_proto
- recover_last_checkpoints
- restore
- save
- set_last_checkpoints
- set_last_checkpoints_with_time
- to_proto
保存模型save方法查看
save(
sess,
save_path,
global_step,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True,
strip_default_attrs=False,
save_debug_info=False
)
return:
#返回checkpoint文件保存的文件夹地址,这个地址可以直接在restore恢复模型的时候使用
参数解析:
- sess:保存变量的会话对象
- save_path: 文件名称,保存checkpoint文件的完整路径,注意,这里是完整文件路径,不是文件夹
- global_step:它会作为checkpoint文件的一个后缀
- meta_graph_suffix:图的结构文件的后缀,默认是“meta", 这个是可以更改的
- write_meta_graph:是否写入graph的meta文件
- write_state:bool类型表示写入是否成功
保存文件的实例
# 保存模型
saver = tf.train.Saver()
# 会话GPU的相关配置
# config 的有关配置
with tf.Session(config = config) as sess:
for epoch in range(epochs):
for i in range(train_batch_count):
# 训练代码
# 每一次 epoch 结束之后保存模型 ,添加global_step参数
save_path = saver.save(sess, "./ckpt_model/keypoint_model.ckpt",global_step=epoch)
print("model has saved,saved in path: %s" % save_path)
需要注意两个地方:
- 每一个文件在原本指定的名称,即“keypoint_mode.dkpt”后面多了一个后缀数字,这个数字就是gobal_step指定的数字
- 因为最大保存数目是5, 所有数字只有5,6,7,8,9
什么时候使用global_step参数呢?
- 在每一个epoch之后,这样global_step = epoch
- 在每一个epoch内部,每隔step保存一次,globa_step = step
checkpoint文件介绍
参考
https://blog.csdn.net/qq_27825451/article/details/105819752