1. Graph对象
Graph对应tensorflow中的数据流图
**使用默认的Graph**
import tensorflow as tf
import numpy as np
a=tf.constant([5,3],name="input_a")
b=tf.reduce_prod(a,name="reduce_prod")
c=tf.reduce_sum(a,name="reduce_sum")
d=tf.add(b,c,name="add_d")
g=tf.Graph()
g1=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
t1=np.array([1,3,2],dtype=np.int32)
t2=np.array([3,4,9],dtype=np.int32)
t3=tf.add(t1,t2)
print(sess.run(t3))
print(sess.run(tf.shape(t3)))
sess.close()
**不使用默认的Graph对象**
import tensorflow as tf
import numpy as np
g=tf.Graph()
g1=tf.Graph()
with g.as_default():
a=tf.constant([5,3],name="input_a")
b=tf.reduce_prod(a,name="reduce_prod")
c=tf.reduce_sum(a,name="reduce_sum")
d=tf.add(b,c,name="add_d")
sess=tf.Session(graph=g1)#当使用多个数据图时,最好的方式是显式传入希望运行的对象,而非在一个with语句块中创建Session对象
with g1.as_default():
t1=np.array([1,3,2],dtype=np.int32)
t2=np.array([3,4,9],dtype=np.int32)
t3=tf.add(t1,t2)
print(sess.run(t3))
print(sess.run(tf.shape(t3)))
sess.close()
2.Session对象
Session类负责数据流图的执行,tf.Session()可接收3种参数,graph参数可传入要加载的Graph对象。
调用tf.Session()创建Session对象:
sess=tf.Session()
sess.run()方法可以接收:
fetches参数:如果执行对象是Tensor对象,run()方法输出为Numpy数组;如果对象为Op(指向某个Op的句柄),则输出为None。
feed_dict参数:用于覆盖数据流图中的Tensor对象,输入为python字典形式{键:值},其中的键为数据流图中Tensor对象的句柄,值的类型必须与该Tensor对象的相同。
import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
a=tf.constant([5,6],name="a")
b=tf.add(a,[2,3])
replace_dict={a:[2,2]}
print(sess.run(b,feed_dict=replace_dict))
Session对象使用结束时要调用其close()方法将不再需要的资源释放。也可以将Session()对象作为上下文管理器使用,这样当代码离开其作用域后,该Session对象将自动关闭
with tf.Session() as sess:
#运行数据流图
…#该部分结束后Session()对象自动关闭
也可以利用Session类的as_default()方法将Session对象作为上下文管理器使用,可被某些函数自己使用。这些函数常见的有Operation.run(),Tensor.eval(),调用这些函数相当于将他们直接传入Session.run()函数,使用结束后需要手动关闭Session对象。
import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
a=tf.constant([5,6],name="input_a")
with sess.as_default():
print(a.eval())
sess.close()#必须手动关闭Session对象
3.占位节点
利用tf.placeholder()创建占位符,其中dtype属性是必须指明的,而shape()属性可选
import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
a=tf.placeholder(tf.int32,shape=[2],name="input_a")#创建一个长度为2数据类型为int32的占位向量
b=tf.placeholder(tf.int32,shape=None,name="input_b")#此处shape的值为None,表示可以接收任意形状的Tensor
c=tf.add(a,b)
with sess.as_default():
input_dict={a:np.array([1,2],dtype=np.int32),b:np.array([3,4],dtype=np.int32)}
print(sess.run(c,feed_dict=input_dict))
sess.close()#必须手动关闭Session对象
4.Variable对象
Variable对象保存可变的张量值。为使用Variable对象,必须对其进行初始化,将所有Variable对象重置为初始值,调用tf.initialize_variables()。只对部分Variable对象重置初始化,调用tf.initialize_variables()参数传入一个要初始化的Variable对象列表。Variable对象通过Variable.assign_add()和Variable.assign_sub()方法进行自增和自减操作。Variable对象可以被多个独立不同的Session对象出事话,每个Session维护自己的Variable值。
TensorFlow提供大量辅助Op,如
tf.zeros([2,2])创建2×2的零矩阵
tf.ones([6])创建长度为6的全1向量
tf.random_uniform([2,2],minval=0,maxval=10)创建2×2的张量,元素服从0-10的均匀分布
tf.random_normal([2×2],means=0.0,stddev=2.0)创建2×2的张量,元素服从0均值标准差为2的正态分布
tf.truncated_normal([2,2],means=5.0,stddev=1.0)创建2×2张量,元素值在[3,7]之间(值不会超过2倍标准差)
import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
a=tf.placeholder(tf.int32,shape=[2],name="input_a")
b=tf.placeholder(tf.int32,shape=None,name="input_b")
my_var1=tf.Variable(8,name="my_var")
my_var2=tf.Variable(tf.ones([6]))
my_var3=my_var1.assign(my_var1+2)
my_var4=my_var2.assign(tf.zeros([6]))#修改的值必须与原值相同shape
init=tf.initialize_all_variables()#初始化全部变量
init2=tf.initialize_variables([my_var1])#初始化部分变量,变量以列表形式传入
add1=tf.add(3,my_var1)
c=tf.add(a,b)
zeros=tf.zeros([2,2])
with sess.as_default():
input_dict={a:np.array([1,2],dtype=np.int32),b:np.array([3,4],dtype=np.int32)}
sess.run(init)#Variable对象必须初始化
print(sess.run(c,feed_dict=input_dict))
print(sess.run(add1))
print(sess.run(my_var3))#值为10
print(sess.run(my_var3))#与my_var1值相同,为12
print(sess.run(my_var4))
print(sess.run(my_var2))
print(sess.run(my_var1))
sess.run(init2)
print(sess.run(my_var1))
print(sess.run(my_var1.assign_add(1)))#自增1
print(sess.run(my_var2))
print(sess.run(my_var3))
sess.close()#必须手动关闭Session对象