关于 TensorFlow 的一些零散知识

关于 TensorFlow 的一些零散知识

TensorFlow 中的内容相当繁杂, 及时总结是一个好习惯; 平时会收集/总结一些有用的知识点和代码片段, 放在本篇博文下是很合适的. 嘻嘻, 我就是想水一篇文章… 🤣🤣🤣

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中;

变量初始化

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())

前两条代码可以处理 Attempting to use uninitialized value 的问题, 最后一条用于处理 LookUpTable not initialized 的问题: 在使用 feature_column 时, 由于 feature 需要查表获取, 这个表也需要进行初始化, 比如:

FailedPreconditionError (see above for traceback): Table not initialized.
         [[node hash_table_Lookup (defined at 5.py:23)  = LookupTableFindV2[Tin=DT_STRING, Tout=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](relationship_lookup/hash_table, to_sparse_input_1/values, relationship_lookup/hash_table/Const)]]

获取 TensorFlow 中变量或者 Op 的 Name

all_vars = tf.global_variables()
for v in all_vars:
    print(v.op.name)

graph = tf.get_default_graph()
for op in graph.get_operations():
    print(op.name)

for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    arr = sess.run(var)

from tensorflow.python.framework import ops
print(tf.get_collection(ops.GraphKeys.MODEL_VARIABLES))

读取 Estimator 对象的 Variable

names = linear_est.get_variable_names()
print('name: ', names)
for i in names:
    print(type(linear_est.get_variable_value(i)))

还有一种方法, 来自: can tf.estimator.Estimator’s parameters be modified by hand?,
通过访问 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 来达到目的, 但如果为了得到模型的权重, 而不是整张图上的变量, 应该访问:

from tensorflow.python.framework import ops
print(tf.get_collection(ops.GraphKeys.MODEL_VARIABLES))

上面链接中的代码如下:

# Restore, Update, Save
# tested only on tesorflow 1.4
import tensorflow as tf
tf.reset_default_graph()

CHECKPOINT_DIR = 'CHECKPOIN_DIR' # for example '/my_checkpoints' as in tf.estimator.LinearClassifier(model_dir='/my_checkpoints'...
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(checkpoint.model_checkpoint_path + '.meta')
    saver.restore(sess, checkpoint.model_checkpoint_path)

    # just to check all variables values
    # sess.run(tf.all_variables())

    # get your variable
    KEY = 'linear/linear_model/0/weights/part_0:0'# for tf.estimator.LinearClassifier first weight
    var_wights_0 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == KEY][0]
    sess.run(var_wights_0)

    # your update operation
    var_wights_0_updated = var_wights_0.assign(var_wights_0 - 100)
    sess.run(var_wights_0_updated)

    # you can check that value is updated
    # sess.run(tf.all_variables())

    # this saves updated values to last checkpoint saved by estimator
    saver.save(sess, checkpoint.model_checkpoint_path)

TensorFlow 将整数转化为字符串

使用 tf.string.format, 来自 Tensorflow - How to Convert int32 to string (using Python API for Tensorflow)

import tensorflow as tf

x = tf.constant([1, 2, 3], dtype=tf.int32)
x_as_string = tf.map_fn(lambda xi: tf.strings.format('{}', xi), x, dtype=tf.string)

with tf.Session() as sess:
  res = sess.run(x_as_string)
  print(res)
  # [b'1' b'2' b'3']

tf.data 介绍

tf.identity 的作用

总的来说, 主要是两个, tf.identity 相当于创建了一个和原始结果一样的新节点, 可以和各种控制流的 op 配合使用, 具体看链接中的例子; 另一个是给 op 命名.

推荐资料

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值