TensorFlow api

0 写在前面

参考资料:

  1. TensorFlow官方文档

1 源节点(source op)

tensorflow计算图中的节点被称之为 op (operation 的缩写).源 op 不需要任何输入,源 op 的输出被传递给其它 op 做运算.

2.2 Tensor

TensorFlow 程序使用 tensor 数据结构来代表所有的数据, 计算图中, 操作间传递的数据都是 tensor. 你可以把 TensorFlow tensor 看作是一个 n 维的数组或列表. 一个 tensor 包含一个静态类型 rank, 和 一个 shape.

2.2.1 占位符

.placeholder(dtype='类型',shape=,name='')

dtype:常见的有,tf.float32(小数默认)/64tf.int8/16/32(整数默认)/64tf.uint8,float,tf.bool,tf.complex64/128float
shape:如[None, 784]None表示无限制。

2.2.2 常量

tf.constant(value,dtype=,name='',shape=)

value:常数或常数矩阵或字符串

2.2.3 变量

tf.Variable(初始值,name='',trainable=True)

初始值:见2.2.5
trainable:若设置为False则表示该变量不在训练过程中被改变值。

y = tf.nn.softmax(tf.matmul(x,W) + b)

2.2.4 函数

2.2.4.1 简单函数

  1. tf.equal(a,b):相等为1,不等为0.
  2. tf.sub(x, a):减法
  3. tf.add(node1, node2):加法
  4. tf.assign(old_value, new_value):分配值
  5. tf.mul(input1, intermed):数字乘法
  6. tf.substract(node1,node2):减法
  7. tf.square():平方
  8. tf.pow(x,n):n次方
  9. tf.argmax(node,axis):返回最大值的index
  10. tf.cast(node,type):转换数据类型。

2.2.4.2 矩阵计算函数

  1. tf.reduce_sum(y_*tf.log(y)):求和
  2. tf.reduce_mean(tf.square(y)):求均值
  3. tf.matmul(x,w,name=''):矩阵乘法

2.2.5 初始值

2.2.5.1 随机值

  1. tf.truncated_normal([IMAGE_PIXELS,hidden1_units],stddev=1.0/math.sqrt(float(IMAGE_PIXELS))),name='weights':正态随机初始值
  2. tf.random_normal(shape=, stddev=0.35,name=):正态分布随机数
  3. tf.random_uniform

2.2.5.2 固定值

  1. tf.zeros:0初始化
  2. 4.0:固定值4.0,float32.

2.2.5.3 由其它节点赋值

  1. tf.Variable(tf.random_normal([784, 200], stddev=0.35),name="weights").initialized_value()

2.2.6 变量初始化节点

必须对变量执行初始化操作。

  1. init = tf.initialize_all_variables():全部初始化
  2. x.initializer():对x初始化
  3. tf.global_variables_initializer():全部初始化

二、高级节点

2.1 激活函数

tf.nn.relu()
tf.nn.softmax()

2.2 损失函数

loss_function=tf.reduce_mean(tf.pow(y-pred,2))
loss_function=tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)):自定义交叉熵损失函数,若pred接近0则会引起异常。
tf.nn.softmax_cross_entropy_with_logits(logits=,label=):结合softmax使用的交叉熵

2.3 优化器

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function):梯度下降优化器,优化最小损失函数
在这里插入图片描述

2.4 .nn

2.4.1 卷积函数

  1. conv2d(input,filter,strides,padding,use_cudnn_on_gpu=None,name=):二维卷积

input:输入的数据,需要是四维的[batch,in_height,in_width,in_channels],要求float32或64。
filter:过滤器,也是四维的[filter_height,filter_width,in_channels,out_channels],同样要求float32或64。
strides:输入数据每一维的步长,是一个长度为四的向量,如[1,1,1,1]
paddingSAME全尺寸填充/VALID不填充。
use_cudnn_on_gpu:bool类型,是否使用cudnn加速。

  1. depthwise_conv2dinput,filter,strides,name=()
  2. separable_conv2d(input,depthwise_filter,pointwise_filter,strides,padding,use_cudnn_on_gpu=None,name=)

2.4.2 池化函数

  1. tf.nn.max_pool(input,ksize,strides,padding,name=)

input:一般就是卷积的结果
ksize:池化窗口的大小,[1,height,width,1]
strides:步长,一般和ksize一致

  1. tf.nn.avg_pool(input,ksize,strides,padding,name=)

2.5 SNN/FC相关

2.5.1 dropout

tf.nn.dropout(node,keep_prob=0.8)

三、会话管理

3.1 会话初始化

  1. sess = tf.Session():创建Session会话对象
  2. sess = tf.InteractiveSession():交互式会话,为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用 InteractiveSession 代替 Session 类, 使用node.eval()node.run()方法代替 Session.run(). 这样可以避免使用一个变量来持有会话.例如
sess = tf.InteractiveSession()
print(node.eval())
sess.close()

3.2 会话基本用法

  1. sess.run(node):执行节点
  2. node = sess.run([node1, node2,...]):执行多个节点
  3. sess.run([output], feed_dict={a:7.0, b:2.0}):Feed提供占位符的值。
  4. result1,result2=sess.run([output], feed_dict={a:[7.0,3.0], b:[2.0.5.0]}):Feed提供多次操作的占位符的值。

3.3 关闭会话

Session 对象在使用完后需要关闭以释放资源.

  1. sess.close()
#除了显式调用 close 外, 也可以使用 "with" 代码块 来自动完成关闭动作.
with tf.Session() as sess:
  result = sess.run(node)
  print result

3.4 指定会话

sess=tf.Session()
with sess.as_default():
	print(node.eval())
#上下两条语句是等价的
print(node.eval(session=sess))

四、保存和加载

用于大样本无法一次性训练完的情况。

4.1 保存

4.1.1 生成检查点文件(checkpoint file)

文件扩展名.ckpt。
saver = tf.train.Saver()
saver = tf.train.Saver({"my_v2": v2}):指定要保存的变量

在会话中save_path = saver.save(sess, "path")

4.1.2 生成图协议文件(graph proto file)

文件扩展名.bp。
tf.train.write_graph(graph,path,name,as_text=True)

as_text

例子

with tf.Session() as sess:
	tf.train.write_graph(sess.graph_def,'path','test_pb.pb',as_text=False)

4.2 加载

4.2.1 加载检查点文件

saver.restore(sess, "Path")

4.2.2 加载图

tf.import_graph_def(graph,name)

五、TensorBoard管理

TensorFlow Python 库有一个默认图 (default graph), op 构造器可以为其增加节点. 这个默认图对 许多程序来说已经足够用了.

  1. tf.reset_default_graph():重置计算图

5.1 tf.summary

summary用于生成TensorBoard日志,tensorboard使用方法在5.2中讲述。

5.1.1 writer

  1. writer=tf.summary.FileWriter('path',1):生成一个写日志的writer,并写入计算图

1tf.get_default_graph()会话默认图/sess.graph会话计算图

  1. writer.add_summary():在会话执行完日志节点后可以写入日志结点的数据。
  2. writer.close()

5.1.2 日志节点

  1. tf.summary.scalar('name',scalar):用于记录标量节点,如损失函数。
  2. merged=tf.summary.merge_all():合并日志文件,方便一次性写入。
  3. tf.summary.image('name',image,k)

k:最多显示k张图片

  1. tf.summary.histogram('name',scalar)
  2. .audio
  3. ``

5.2 启动TensorBoard

Anaconda Prompt中,先进入日志存放的文件夹,然后输入命令
tensorboard --logdir=/logpath

5.3 命名空间

with tf.name_scope('name')

六、迭代训练

七、评估

定义准确率:

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

8 练习用数据集

8.1 MNIST

MNIST手写数字识别数据集

import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

在这里插入图片描述

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 技术工厂 设计师:CSDN官方博客 返回首页