在TensorFlow中
- 使用图 (graph) 来表示计算任务
- 在会话(Session)的上下文中执行图
- 用tensor来表示数据
- 通过变量Variable来对状态进行维护
- 使用feed来为操作赋值或者从中取出数据
下面这段代码可以让你对TensorFlow有个很直观的感受
# -*- coding: UTF-8 -*-
import tensorflow as tf
# 创建一个常量 op, 产生一个 1x2 矩阵. 这个 op 被作为一个节点
# 加到默认图中.
# 构造器的返回值代表该常量 op 的返回值.
matrix1 = tf.constant([[3., 3.]])
# 创建另外一个常量 op, 产生一个 2x1 矩阵.
matrix2 = tf.constant([[2.],[2.]])
# 创建一个矩阵乘法 matmul op , 把 'matrix1' 和 'matrix2' 作为输入.
# 返回值 'product' 代表矩阵乘法的结果.
product = tf.matmul(matrix1, matrix2)
# 启动默认图.
# 调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.
# 上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回
# 矩阵乘法 op 的输出.
#
# 整个执行过程是自动化的, 会话负责传递 op 所需的全部输入. op 通常是并发执行的.
#
# 函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op) 的执行.
#
# 返回值 'result' 是一个 numpy `ndarray` 对象.
with tf.Session() as sess:
result = sess.run(product)
print(result)
# ==> [[ 12.]]
1、变量的使用
TensorFlow中的变量和python、java中的变量是类似的,都是代指真实的某个对象。如果使用了变量,则必须要使用下面代码中的第14行进行变量的初始化操作
# -*- coding: UTF-8 -*-
import tensorflow as tf
"""
变量的使用
"""
# 创建变量
a = tf.Variable([2, 3])
b = tf.constant([1, 2])
# op:矩阵相减
sub_res = tf.subtract(a, b)
# Returns an Op that initializes global variables
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
sub_res = sess.run(sub_res)
print(sub_res)
# [1 1]
声明一个变量,实现自增的效果
# -*- coding: UTF-8 -*-
import tensorflow as tf
# 变量自增
zero = tf.Variable(0, name="counter")
# 给zero的值1
plus_res = tf.add(zero, 1)
# assign是分配的意思,将zero的值分配给plus_res
update_value = tf.assign(zero, plus_res)
# 初始化变量
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(sess.run(zero))
for _ in range(10):
sess.run(update_value)
print(sess.run(zero))
2、placeholder的使用
placeholder顾名思义就是占位符,先声明一个变量在这里,你可以在后续的操作里进行使用,但是这会是没值的,这会不想给你,任性~
当启用Session的之后,使用session.run(),在参数列表中,使用feed_dict进行placeholder的赋值。
# -*- coding: UTF-8 -*-
import tensorflow as tf
# feed用法
# 预先设置placeholder,在session.run()中放入实际参数
num_1 = tf.placeholder(tf.float32)
num_2 = tf.placeholder(tf.float32)
res = tf.multiply(num_1, num_2)
with tf.Session() as sess:
res = sess.run(res, feed_dict={num_1: 2, num_2: 4})
print(res)
# 8.0