TensorFlow常用指令记录

作者:qxl 邮箱: 1183129553@qq.com


系列文章链接

一、tensorflow安装方式及问题汇总
二、TensorFlow基础概念
三、神经网络结构设计
四、mnist数字识别问题
五、图像识别与卷积神经网络
六、U-NET网络
七、TensorFlow常用指令记录



前言

最近一直在学习tensorflow,一些常用的指令,如果只是看看,经常会遗忘。 ***

常用指令

tf.argmax

correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

dimension=0 按列找
dimension=1 按行找
tf.argmax()返回最大数值的下标
通常和tf.equal()一起使用,计算模型准确度

tf.cast

类型转换

tf.equal

import tensorflow as tf
a = [[1,2,3],[4,5,6]]
b = [[1,0,3],[1,5,1]]
with tf.Session() as sess:
    print(sess.run(tf.equal(a,b)))

输出:

[[ True False  True]
 [False  True False]]

变量管理-tf.name_scope() 和 tf.variable_scope()

tensorflow提供了通过变量名称来创建或者获取变量的机制。通过此机制,在不同函数中可直接通过变量名字来使用变量,而不需要以参数的形式传递。简单来说,就是为了变量共享(传递)。通过变量名称获取变量的机制主要通过tf.get_variable和tf.variable_scope函数来实现的。
下面介绍一下相关概念:

  1. tf.get_variable方法负责创建和获取指定名称的变量
  2. tf.Variable在创建变量时,一律创建新的变量,如果这个变量已存在,则后缀会增加0、1、2等数字编号予以区别。对于,tf.Variable()而言,两种域(name_scope 和 variable_scope)没有差别,输出变量名都有前缀。
  3. tf.variable_scope负责管理命名空间。tf.variable_scope传入的第一个参数为命名空间的名字。需要注意的是reuse这个参数。默认值为False。当参数reuse=False时,在第二次调用get_variable函数的时候,会抛出异常,显示变量已经存在。reuse=True表示共享该作用域内的参数,即可以多次调用。tf.get_variable()方式创建的变量,只有variable_scope名称会加到变量名称前面,而name_scope不会作为前缀
  4. name_scope暂时不是很了解,后续补充

tf.get_variable() 而不用**tf.Variable()**原因:为前者拥有一个变量检查机制,会检测已经存在的变量是否设置为共享变量,如果已经存在的变量没有设置为共享变量,TensorFlow 运行到第二个拥有相同名字的变量的时候,就会报错。

代码示例:

import tensorflow as tf

with tf.variable_scope('V1_domain', reuse=tf.AUTO_REUSE):
    a1 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
    a2 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a2')
with tf.variable_scope('V2_domain', reuse=tf.AUTO_REUSE):
    a3 = tf.get_variable(name='a1', shape=[1], initializer=tf.constant_initializer(1))
    a4 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a2')

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(a1.name)
    print(a2.name)
    print(a3.name)
    print(a4.name)

with tf.name_scope('V1_name_domain'):
    a11 = tf.get_variable(name='a11', shape=[1], initializer=tf.constant_initializer(1))
    a12 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a12')
with tf.name_scope('V2_name_domain'):
    #a13 = tf.get_variable(name='a11', shape=[1], initializer=tf.constant_initializer(1))
    a14 = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='a12')

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(a11.name)
    print(a12.name)
    #print(a13.name)
    print(a14.name)

输出结果:

V1_domain/a1:0
V1_domain/a2:0
V2_domain/a1:0
V2_domain/a2:0
a11:0
V1_name_domain/a12:0
V2_name_domain/a12:0

tensorflow checkpoint

介绍tensorflow2.x和tensorflow1.x中如何保存checkpoint。

1. 什么是checkpoint?

检查点checkpoint(二进制文件)中存储着模型model所使用的所有tf.Variable对象,不包含关于模型的计算信息。只在我们恢复模型的时候checkpoint才有用,因为不知道模型的结构,只知道变量信息是没有意义的。

保存

tf2.x版本
model.save_weights("path_to_my_tf_checkpoints")
tf1.x版本
saver = tf.train.Saver()
saver.saver(sess, model_path+model_name)

一般格式:外层创建saver对象,session里边训练之后保存

with tf.Session() as sess:
	...training...
	saver.save(sess, modelpath+modelname)

saver类介绍

先看一下函数原型,

tf.train.Saver(
    var_list=None, 
    reshape=False, 
    sharded=False, 
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0, 
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

常用的参数介绍:

  • var_list: 保存的变量Variable列表,也可以是一个字典映射,表示我们需要保存哪一些变量,默认情况下不用制定,表示保存所有的变量。
  • max_to_keep:保存最近的几份检查点,默认是5,及保存最后5分检查点。
  • keep_checkpoint_every_n_hours=10000.0:隔多少个小时保存一次检查点,默认是10000小时
saver类的属性和常用方法

Saver类的属性和常用方法:
Saver类的属性

  • last_checkpoints

Saver类的方法:

  • as_saver_def
  • build
  • export_meta_graph
  • from_proto
  • recover_last_checkpoints
  • restore
  • save
  • set_last_checkpoints
  • set_last_checkpoints_with_time
  • to_proto
保存模型save方法查看
save(
	sess, 
	save_path,
	global_step,
	latest_filename=None,
	meta_graph_suffix='meta',
	write_meta_graph=True,
	write_state=True,
	strip_default_attrs=False,
	save_debug_info=False
)

return:
#返回checkpoint文件保存的文件夹地址,这个地址可以直接在restore恢复模型的时候使用

参数解析:

  • sess:保存变量的会话对象
  • save_path: 文件名称,保存checkpoint文件的完整路径,注意,这里是完整文件路径,不是文件夹
  • global_step:它会作为checkpoint文件的一个后缀
  • meta_graph_suffix:图的结构文件的后缀,默认是“meta", 这个是可以更改的
  • write_meta_graph:是否写入graph的meta文件
  • write_state:bool类型表示写入是否成功
保存文件的实例
# 保存模型
saver = tf.train.Saver()
 
# 会话GPU的相关配置
# config 的有关配置
with tf.Session(config = config) as sess:
    for epoch in range(epochs): 
        for i in range(train_batch_count):
            # 训练代码
    
        # 每一次 epoch 结束之后保存模型 ,添加global_step参数  
        save_path = saver.save(sess, "./ckpt_model/keypoint_model.ckpt",global_step=epoch)
        print("model has saved,saved in path: %s" % save_path)

在这里插入图片描述
需要注意两个地方:

  • 每一个文件在原本指定的名称,即“keypoint_mode.dkpt”后面多了一个后缀数字,这个数字就是gobal_step指定的数字
  • 因为最大保存数目是5, 所有数字只有5,6,7,8,9

什么时候使用global_step参数呢?

  • 在每一个epoch之后,这样global_step = epoch
  • 在每一个epoch内部,每隔step保存一次,globa_step = step
checkpoint文件介绍

参考
https://blog.csdn.net/qq_27825451/article/details/105819752

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值