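Foreword: I previously published a reposted series on TensorFlow 1.x and typed through the basic TensorFlow operations myself, but there were still many places I didn't understand deeply. So I'm starting a new series, following a tutorial I found online and typing everything again as I listen. The end goal is a news-classification demo; I already have its code but didn't understand it before. After that there will probably be a PyTorch series, with the final goal of reimplementing the TensorFlow 1.x code in PyTorch.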
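1. Basic introduction
TensorFlow 1.x has its own programming model. Unlike ordinary Python code, the tensors and operations you define do not produce results right away; they must be evaluated inside a session.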
import tensorflow as tf
# Implement an addition operation
a = tf.constant(5.0)
b = tf.constant(6.0)
sum1 = tf.add(a,b)
# Inspect the results
print(a,b)
print(sum1)
Tensor("Const:0", shape=(), dtype=float32) Tensor("Const_1:0", shape=(), dtype=float32)
Tensor("Add:0", shape=(), dtype=float32)
# Run it in a session
with tf.Session() as sess:
    print(sess.run(sum1))
11.0
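Printing `sum1` shows a `Tensor` description rather than 11.0. The pure-Python toy below models the "define first, run later" idea behind that. The class names `Const`, `Add` and `Session` are made up for illustration; this is only an analogy, not how TensorFlow is implemented.

```python
# Toy model of TensorFlow 1.x "define, then run" (not TensorFlow's real implementation).
class Const:
    """A node holding a fixed value; nothing is computed when it is created."""
    def __init__(self, value):
        self.value = value

    def evaluate(self):
        return self.value


class Add:
    """A node that only records its two inputs; the sum is computed on evaluate()."""
    def __init__(self, x, y):
        self.x, self.y = x, y

    def evaluate(self):
        return self.x.evaluate() + self.y.evaluate()


class Session:
    """Runs a node by recursively evaluating the nodes it depends on."""
    def run(self, fetch):
        return fetch.evaluate()


a = Const(5.0)
b = Const(6.0)
sum1 = Add(a, b)              # only describes the computation
print(type(sum1).__name__)    # Add: a node, not a number
print(Session().run(sum1))    # 11.0
```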
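2. Graph structure
Some concepts:
tensor: the data itself (an n-dimensional array)
operation (op): a node that performs a computation (note: every operation is an op)
graph: the structure of your whole program
Session: runs the program's graph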
# Get the default graph; the printed result shows where the graph lives in memory
graph = tf.get_default_graph()
graph
<tensorflow.python.framework.ops.Graph at 0x1c82a43df60>
# The graph attribute of each tensor and of the session prints the same address: they all belong to this one graph
with tf.Session() as sess:
    print(sess.run(sum1))
    print(a.graph)
    print(sum1.graph)
    print(sess.graph)
11.0
<tensorflow.python.framework.ops.Graph object at 0x000001C82A43DF60>
<tensorflow.python.framework.ops.Graph object at 0x000001C82A43DF60>
<tensorflow.python.framework.ops.Graph object at 0x000001C82A43DF60>
# A program usually has just one graph, so how do we define another one?
g = tf.Graph()
print(g)
<tensorflow.python.framework.ops.Graph object at 0x000001C82A48E588>
# The new graph has a different memory address from the default graph
with g.as_default():
    c = tf.constant(11.0)
    print(c)
    print(c.graph)
Tensor("Const:0", shape=(), dtype=float32)
<tensorflow.python.framework.ops.Graph object at 0x000001C82A48E588>
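The pure-Python sketch below imitates the default-graph behaviour. `Graph`, `as_default`, `get_default_graph` and `Const` are hypothetical stand-ins named after the TensorFlow calls. The sketch rests on two assumptions: each new node attaches itself to whatever graph is the default when it is created, and node names are made unique per graph. Both match the `Const:0`, `Const_1:0` and new-graph `Const:0` names printed above.

```python
import contextlib


class Graph:
    """Toy graph: owns its nodes and hands out names unique within this graph."""
    def __init__(self):
        self.nodes = []
        self._name_counts = {}

    def unique_name(self, base):
        count = self._name_counts.get(base, 0)
        self._name_counts[base] = count + 1
        return base if count == 0 else f"{base}_{count}"

    @contextlib.contextmanager
    def as_default(self):
        # Push this graph so nodes created inside the `with` block join it.
        _graph_stack.append(self)
        try:
            yield self
        finally:
            _graph_stack.pop()


_graph_stack = [Graph()]  # the bottom entry plays the role of the default graph


def get_default_graph():
    return _graph_stack[-1]


class Const:
    """Toy constant: joins the current default graph when created."""
    def __init__(self, value):
        self.value = value
        self.graph = get_default_graph()
        self.name = self.graph.unique_name("Const") + ":0"
        self.graph.nodes.append(self)


a = Const(5.0)
b = Const(6.0)
g = Graph()
with g.as_default():
    c = Const(11.0)

print(a.name, b.name, c.name)                            # Const:0 Const_1:0 Const:0
print(a.graph is b.graph, c.graph is g, c.graph is a.graph)  # True True False
```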
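Summary:
A graph holds a set of ops and tensors and acts as a context (entered with as_default())
op: any operation defined through the TensorFlow API is an op
tensor: simply refers to the data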
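3. Sessions and the session's run method
Understanding:
TensorFlow can be viewed as two parts:
Front end: defines the structure of the program's graph
Back end: executes the graph
What a session does:
1. Runs the graph
2. Allocates resources for the computation
3. Holds resources (variables, queues, threads)
A session runs only one graph; you choose which graph it runs with tf.Session(graph=...)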
# A session runs only one graph; tf.Session() with no arguments uses the default graph
with tf.Session() as sess:
    print(sess.run(c))
    print(a.graph)
    print(sum1.graph)
    print(sess.graph)
# This raises an error because c belongs to a different graph (g)
# ValueError: Fetch argument <tf.Tensor 'Const:0' shape=() dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("Const:0", shape=(), dtype=float32) is not an element of this graph.)
ValueError: Fetch argument <tf.Tensor 'Const:0' shape=() dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("Const:0", shape=(), dtype=float32) is not an element of this graph.)
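Running a specific graph
# Run a specific graph
with tf.Session(graph = g) as sess:
    print(sess.run(c))
    print(a.graph) # works: reading .graph is plain attribute access and runs nothing in this session
    print(sum1.graph) # same as above
    print(sess.graph)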
11.0
<tensorflow.python.framework.ops.Graph object at 0x000001C82A43DF60>
<tensorflow.python.framework.ops.Graph object at 0x000001C82A43DF60>
<tensorflow.python.framework.ops.Graph object at 0x000001C82A48E588>
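The toy sketch below shows why `sess.run(c)` fails in a default-graph session but works with `tf.Session(graph=g)`. A session is bound to one graph and rejects fetches that belong to another. The classes are hypothetical pure-Python stand-ins, and the error text only imitates TensorFlow's message.

```python
class Graph:
    """Toy graph: only its identity matters here."""


DEFAULT_GRAPH = Graph()


class Const:
    """Toy constant that remembers which graph it was created in."""
    def __init__(self, value, graph=DEFAULT_GRAPH):
        self.value = value
        self.graph = graph


class Session:
    """Toy session bound to exactly one graph, like tf.Session(graph=...)."""
    def __init__(self, graph=None):
        self.graph = graph if graph is not None else DEFAULT_GRAPH

    def run(self, fetch):
        if fetch.graph is not self.graph:
            raise ValueError("Tensor is not an element of this graph.")
        return fetch.value


g = Graph()
a = Const(5.0)       # created in the default graph
c = Const(11.0, g)   # created in g

try:
    Session().run(c)  # the default-graph session cannot fetch c
except ValueError as e:
    print("ValueError:", e)

print(Session(graph=g).run(c))   # 11.0
print(a.graph is DEFAULT_GRAPH)  # True: reading .graph needs no session at all
```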
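About sess.run(fetches, feed_dict=None, options=None, run_metadata=None). The graph argument belongs to tf.Session(), not to run():
1. Purpose: runs ops and evaluates tensors. It launches the computation, executing the part of the graph needed to produce the requested fetches.
2. sess.close(): releases the session's resources. It is the counterpart of creating the session with tf.Session(), and it can be omitted when the session is used as a context manager (with tf.Session() as sess: ...), which closes the session automatically.
Other:
3. tf.Session() also takes a config argument, e.g. tf.Session(config=tf.ConfigProto(log_device_placement=True)). It logs which device each op is placed on, along with other details.
4. Interactive use: at an interactive prompt, tf.InteractiveSession() is handy for debugging, especially together with calling .eval() on a tensor or variable.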
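The toy Python class below illustrates the close()/context-manager point. It is hypothetical, not TensorFlow's `Session`. Here `__exit__` calls `close()` automatically, and using the session after it has been closed raises `RuntimeError`.

```python
class Session:
    """Toy session showing why `with` makes an explicit close() unnecessary."""
    def __init__(self):
        self.closed = False

    def run(self, value):
        if self.closed:
            raise RuntimeError("Attempted to use a closed Session.")
        return value  # stand-in for evaluating a real fetch

    def close(self):
        self.closed = True  # "release resources"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()   # runs even if the body raised
        return False   # do not swallow exceptions


with Session() as sess:
    print(sess.run(11.0))  # 11.0
print(sess.closed)         # True: closed automatically on leaving the block

manual = Session()
try:
    manual.run(1.0)
finally:
    manual.close()         # explicit close() when not using `with`
```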
Example 1: running several fetches at once
# Example 1: run several fetches at once
with tf.Session() as sess:
    print(sess.run([a,b,sum1]))
[5.0, 6.0, 11.0]
Example 2-1: run only accepts ops and tensors (sum2 below is a plain Python int, 5)
# Example 2-1: run only accepts ops and tensors
var1 = 2
var2 = 3
sum2 = var1 + var2
with tf.Session() as sess:
    print(sess.run(sum2))
# TypeError: Fetch argument 5 has invalid type <class 'int'>, must be a string or Tensor. (Can not convert a int into a Tensor or Operation.)
TypeError: Fetch argument 5 has invalid type <class 'int'>, must be a string or Tensor. (Can not convert a int into a Tensor or Operation.)
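Examples 1 and 2-1 can be mimicked with a toy `run` function, which is hypothetical and written in pure Python. A list of fetches comes back as a list of values in the same order. A plain Python `int` such as `sum2 = 2 + 3` is different: Python computed it immediately, so it is not a graph element and gets rejected with `TypeError`. Unlike TensorFlow, this sketch re-evaluates shared inputs instead of computing each node once per run.

```python
class Node:
    """Toy graph node: stores how to compute its value, not the value itself."""
    def __init__(self, compute):
        self.compute = compute  # zero-argument function producing the value


def constant(value):
    return Node(lambda: value)


def add(x, y):
    return Node(lambda: x.compute() + y.compute())


def run(fetches):
    """Toy sess.run: a node gives a value, a list of fetches gives a list of values."""
    if isinstance(fetches, list):
        return [run(f) for f in fetches]
    if not isinstance(fetches, Node):
        # An int such as 5 was already computed by Python; it is not a graph element.
        raise TypeError(f"Fetch argument {fetches!r} has invalid type {type(fetches)}")
    return fetches.compute()


a = constant(5.0)
b = constant(6.0)
sum1 = add(a, b)
print(run([a, b, sum1]))  # [5.0, 6.0, 11.0]

sum2 = 2 + 3  # plain Python arithmetic: sum2 is the int 5
try:
    run(sum2)
except TypeError as e:
    print("TypeError:", e)  # TypeError: Fetch argument 5 has invalid type <class 'int'>
```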
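Example 2-2: thanks to operator overloading, an operator with a tensor operand (here a + var1) is turned into an op
# Example 2-2: the + below creates an op because a is a tensor
var1 = 2.0
sum2 = a + var1
with tf.Session() as sess:
    print(sess.run(sum2))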
7.0
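Example 2-2 works because TensorFlow's `Tensor` class overloads `+`. The toy version below sketches that mechanism with hypothetical classes in pure Python. When at least one operand is a `Tensor`, Python dispatches to `__add__`/`__radd__`. That method wraps the plain number as a constant and returns a new `Add` node. `2 + 3` involves no `Tensor`, so Python simply computes 5.

```python
class Tensor:
    """Toy tensor: `+` builds a new node instead of computing a number."""
    def __init__(self, compute, op):
        self.compute = compute  # zero-argument function producing the value
        self.op = op            # name of the operation that created this node

    def __add__(self, other):   # handles  a + 2.0
        other = to_tensor(other)
        return Tensor(lambda: self.compute() + other.compute(), "Add")

    def __radd__(self, other):  # handles  2.0 + a
        return to_tensor(other) + self


def to_tensor(x):
    """Wrap a plain Python number as a constant node (like an implicit constant op)."""
    return x if isinstance(x, Tensor) else Tensor(lambda: x, "Const")


a = to_tensor(5.0)
var1 = 2.0
sum2 = a + var1
print(sum2.op)               # Add
print(sum2.compute())        # 7.0
print(type(2 + 3).__name__)  # int
```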
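4. placeholder
Use case: when training a model, the data has to be supplied at run time
Introduction:
1. A placeholder reserves a slot for data; at run time it serves as a key in the feed_dict dictionary
2. Signature: tf.placeholder(dtype, shape=None, name=None)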
Example 1: using a placeholder
# Example 1: using a placeholder
plt = tf.placeholder(tf.float32,[2,3])
with tf.Session(config = tf.ConfigProto(log_device_placement = True)) as sess:
    print(sess.run(plt,feed_dict = {plt:[[1,2,3],[4,5,6]]}))
[[1. 2. 3.]
[4. 5. 6.]]
Example 2: using a placeholder, part 2 (variable number of sample rows)
# Example 2: placeholder with a variable number of sample rows
plt = tf.placeholder(tf.float32,[None,3]) # n rows, 3 columns
with tf.Session(config = tf.ConfigProto(log_device_placement = True)) as sess:
    print(sess.run(plt,feed_dict = {plt:[[1,2,3],[4,5,6],[7,8,9]]}))
[[1. 2. 3.]
[4. 5. 6.]
[7. 8. 9.]]
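The toy `Placeholder` below sketches how `feed_dict` and a `None` dimension interact. It is hypothetical and written in pure Python, with Python `float` standing in for `tf.float32`. The placeholder holds no data. `run` checks each fed value's shape against the declared one, treating `None` as "any length". It then converts the values to the declared type, which is why the integers fed above come back as floats. The `TypeError`/`ValueError` choices follow the exception list in section 5, but the messages are invented.

```python
class Placeholder:
    """Toy placeholder: declares a type and a shape, holds no value; None means any size."""
    def __init__(self, dtype, shape=None):
        self.dtype = dtype
        self.shape = shape


def shape_of(value):
    """Shape of a nested list, e.g. [[1, 2, 3], [4, 5, 6]] -> (2, 3); assumes rectangular input."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0] if value else None
    return tuple(dims)


def convert(value, dtype):
    """Convert every element of a nested list to dtype."""
    return [convert(v, dtype) for v in value] if isinstance(value, list) else dtype(value)


def run(fetch, feed_dict):
    """Toy sess.run for a single placeholder fetch."""
    for key, value in feed_dict.items():
        if not isinstance(key, Placeholder):
            raise TypeError(f"feed_dict key {key!r} is not a placeholder")
        actual = shape_of(value)
        if key.shape is not None:
            ok = len(actual) == len(key.shape) and all(
                want is None or want == got for want, got in zip(key.shape, actual))
            if not ok:
                raise ValueError(f"cannot feed shape {actual} into placeholder of shape {tuple(key.shape)}")
    if fetch not in feed_dict:
        raise ValueError("the fetched placeholder must be fed a value")
    return convert(feed_dict[fetch], fetch.dtype)


plt = Placeholder(float, [None, 3])  # any number of rows, exactly 3 columns
print(run(plt, {plt: [[1, 2, 3], [4, 5, 6], [7, 8, 9]]}))
# [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
try:
    run(plt, {plt: [[1, 2], [3, 4]]})  # 2 columns instead of 3
except ValueError as e:
    print("ValueError:", e)
```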
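5. Exceptions raised by sess.run()
RuntimeError: if the Session is in an invalid state (e.g. it has been closed).
TypeError: if fetches or feed_dict keys are of an inappropriate type.
ValueError: if fetches or feed_dict keys are invalid or refer to a Tensor that does not exist.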