tensorflow学习(一)

1. Graph对象

Graph对应tensorflow中的数据流图
    **使用默认的Graph**
import tensorflow as tf
import numpy as np
a=tf.constant([5,3],name="input_a")
b=tf.reduce_prod(a,name="reduce_prod")
c=tf.reduce_sum(a,name="reduce_sum")
d=tf.add(b,c,name="add_d")
g=tf.Graph()
g1=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
    t1=np.array([1,3,2],dtype=np.int32)
    t2=np.array([3,4,9],dtype=np.int32)
    t3=tf.add(t1,t2)
    print(sess.run(t3))
    print(sess.run(tf.shape(t3)))
    sess.close()
    **不使用默认的Graph对象**
import tensorflow as tf
import numpy as np
g=tf.Graph()
g1=tf.Graph()
with g.as_default():
    a=tf.constant([5,3],name="input_a")
    b=tf.reduce_prod(a,name="reduce_prod")
    c=tf.reduce_sum(a,name="reduce_sum")
    d=tf.add(b,c,name="add_d")
sess=tf.Session(graph=g1)#当使用多个数据图时,最好的方式是显式传入希望运行的对象,而非在一个with语句块中创建Session对象
with g1.as_default():
    t1=np.array([1,3,2],dtype=np.int32)
    t2=np.array([3,4,9],dtype=np.int32)
    t3=tf.add(t1,t2)
    print(sess.run(t3))
    print(sess.run(tf.shape(t3)))
    sess.close()

2.Session对象

Session类负责数据流图的执行,tf.Session()可接收3种参数,graph参数可传入要加载的Graph对象。
调用tf.Session()创建Session对象:
sess=tf.Session()
sess.run()方法可以接收:
fetches参数:如果执行对象是Tensor对象,run()方法输出为Numpy数组;如果对象为Op(指向某个Op的句柄),则输出为None。
feed_dict参数:用于覆盖数据流图中的Tensor对象,输入为python字典形式{键:值},其中的键为数据流图中Tensor对象的句柄,值的类型必须与该Tensor对象的相同。

import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
    a=tf.constant([5,6],name="a")
    b=tf.add(a,[2,3])
    replace_dict={a:[2,2]}
    print(sess.run(b,feed_dict=replace_dict))

Session对象使用结束时要调用其close()方法将不再需要的资源释放。也可以将Session()对象作为上下文管理器使用,这样当代码离开其作用域后,该Session对象将自动关闭
with tf.Session() as sess:
#运行数据流图
…#该部分结束后Session()对象自动关闭
也可以利用Session类的as_default()方法将Session对象作为上下文管理器使用,可被某些函数自己使用。这些函数常见的有Operation.run(),Tensor.eval(),调用这些函数相当于将他们直接传入Session.run()函数,使用结束后需要手动关闭Session对象。

import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
    a=tf.constant([5,6],name="input_a")
    with sess.as_default():
        print(a.eval())
    sess.close()#必须手动关闭Session对象

3.占位节点

利用tf.placeholder()创建占位符,其中dtype属性是必须指明的,而shape()属性可选

import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
    a=tf.placeholder(tf.int32,shape=[2],name="input_a")#创建一个长度为2数据类型为int32的占位向量
    b=tf.placeholder(tf.int32,shape=None,name="input_b")#此处shape的值为None,表示可以接收任意形状的Tensor
    c=tf.add(a,b)
    with sess.as_default():
        input_dict={a:np.array([1,2],dtype=np.int32),b:np.array([3,4],dtype=np.int32)}
        print(sess.run(c,feed_dict=input_dict))
    sess.close()#必须手动关闭Session对象

4.Variable对象

Variable对象保存可变的张量值。为使用Variable对象,必须对其进行初始化,将所有Variable对象重置为初始值,调用tf.initialize_variables()。只对部分Variable对象重置初始化,调用tf.initialize_variables()参数传入一个要初始化的Variable对象列表。Variable对象通过Variable.assign_add()和Variable.assign_sub()方法进行自增和自减操作。Variable对象可以被多个独立不同的Session对象出事话,每个Session维护自己的Variable值。
TensorFlow提供大量辅助Op,如
tf.zeros([2,2])创建2×2的零矩阵
tf.ones([6])创建长度为6的全1向量
tf.random_uniform([2,2],minval=0,maxval=10)创建2×2的张量,元素服从0-10的均匀分布
tf.random_normal([2×2],means=0.0,stddev=2.0)创建2×2的张量,元素服从0均值标准差为2的正态分布
tf.truncated_normal([2,2],means=5.0,stddev=1.0)创建2×2张量,元素值在[3,7]之间(值不会超过2倍标准差)

import tensorflow as tf
import numpy as np
g=tf.get_default_graph()
sess=tf.Session(graph=g)
with g.as_default():
    a=tf.placeholder(tf.int32,shape=[2],name="input_a")
    b=tf.placeholder(tf.int32,shape=None,name="input_b")
    my_var1=tf.Variable(8,name="my_var")
    my_var2=tf.Variable(tf.ones([6]))
    my_var3=my_var1.assign(my_var1+2)
    my_var4=my_var2.assign(tf.zeros([6]))#修改的值必须与原值相同shape
    init=tf.initialize_all_variables()#初始化全部变量
    init2=tf.initialize_variables([my_var1])#初始化部分变量,变量以列表形式传入
    add1=tf.add(3,my_var1)
    c=tf.add(a,b)
    zeros=tf.zeros([2,2])
    with sess.as_default():
        input_dict={a:np.array([1,2],dtype=np.int32),b:np.array([3,4],dtype=np.int32)}
        sess.run(init)#Variable对象必须初始化
        print(sess.run(c,feed_dict=input_dict))
        print(sess.run(add1))
        print(sess.run(my_var3))#值为10
        print(sess.run(my_var3))#与my_var1值相同,为12
        print(sess.run(my_var4))
        print(sess.run(my_var2))
        print(sess.run(my_var1))
        sess.run(init2)
        print(sess.run(my_var1))
        print(sess.run(my_var1.assign_add(1)))#自增1
        print(sess.run(my_var2))
        print(sess.run(my_var3))

    sess.close()#必须手动关闭Session对象
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值