tf.reshape() tf.shape(x)与x.get_shape()
tf.rehspe()用法 : tf.reshape(tensor, shape, name=None)
>>> import tensorflow as tf
>>> import numpy as np
## 创建一个数组a
>>> a = np.arange(24)
>>> a
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23])
① reshape中shape默认用列表传入
>>> tf.reshape(a,[12,2])
<tf.Tensor 'Reshape:0' shape=(12, 2) dtype=int32>
>>> sess.run(tf.reshape(a,[4,6]))
array([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
② reshape中shape中的-1用法
## 行数固定4,列数默认计算
>>> sess.run(tf.reshape(a,[4,-1]))
array([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
## 列数固定4,行数默认计算
>>> sess.run(tf.reshape(a,[-1,4]))
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]])
③ reshape中shape中传入三个参数的意思
就是增加了一维,表示多个通道或者多张输入图片,如果是四个参数一般是[batch_size,channel,img_h,img_w].
>>> sess.run(tf.reshape(a,[2,3,4]))
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tf.shape(a)用法和a.get_shape()
① 两个函数都可以的到tensor的尺寸
② tf.shape(a)中的数据类型可以是tensor,array,list 但是a.get_shape 只能是tensor, 且返回值是元组
##创建一个数组a
>>> a = np.arange(24)
>>> a
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23])
## 从a的array中创建一个tensor b
>>> b = tf.reshape(a,[4,6])
>>> b
<tf.Tensor 'Reshape_11:0' shape=(4, 6) dtype=int32>
>>> sess.run(b)
array([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]])
## 创建一个列表c
>>> c = [1,2,3]
>>> c
[1, 2, 3]
我们来看tf.shape(x)函数作用a,b,c的输出:
>>> sess.run(tf.shape(a))
array([24])
>>> sess.run(tf.shape(b))
array([4, 6])
>>> sess.run(tf.shape(c))
array([3])
我们来看x.get_shape()函数作用a,b,c的输出:
>>> a.get_shape()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'
>>> b.get_shape()
TensorShape([Dimension(4), Dimension(6)])
>>> c.get_shape()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'list' object has no attribute 'get_shape'
说明x.get_shape()中,x只能是tensor,否则报错,并且返回的是一个元组,可以分别取出行和列:
>>> b.get_shape()
TensorShape([Dimension(4), Dimension(6)])
>>> print(b.get_shape())
(4, 6)
>>> print( b.get_shape()[0])
4
>>> print( b.get_shape()[1])
6
>>>print(b.get_shape().as_list())
>>> b.get_shape()[0].value
4
>>> b.get_shape()[1].value
6
参考: https://blog.csdn.net/fireflychh/article/details/73611021
[4, 6][4, 6]
上面都是别人的代码和说明。最后自我理解就是:
假如一个矩阵是三维的,维度2*4*3,
[[[a111 a112 a113],[a121 a122 a123],[a131 a132 a133],[a141 a142 a143]],[[a211 a212 a213],[a221 a222 a223],~~~]]]
矩阵后面的省略,我的意思就是用 下标表示这个矩阵。维度与我下面的代码矩阵一样。
1 tf.transpose(b,[1,0,2]) #表示要调整b矩阵的维度,从2*4*3变成4*2*3,0\1\2表示矩阵原来的维度,原来是[0,1,2],按行读取,
现在的矩阵[[[a111 a112 a113],[a211 a212 a213]],[[a121 a122 a123],[a231 a232 a233]]~~]]]
把[a111 a112 a113]看成一个整体,那么就是[a11 a12 a13 a14
a21 a22 a23 a24]
第一个与第二个维度交换,也就是原来的a12变成了现在b21,新的矩阵是
[b11 b12
b21 b22
b31 b32
b41 b42]
不知道说清楚没有。可以看看二维矩阵转置。
tf.reshape(bt,[-1,3])的用法相关,就放在一起了。第一个参数表示要变形状的张量(二维的叫矩阵,一维的叫数组),第二个参数表示目标张量的形状(shape=[])。需保证目标张量的形状正确,比如原来的是3*3*4(=36)的,那么目标的形状就必须能被36整除,例如7*2,就会报错。当形状(shape)的参数第一个是-1时,表示排成一维数组,每个数a1是1*3的数组。
tf.split(dt,4,0)的用法相关,一样放一起。第一个参数表示要被分的张量,第二个参数表示均分成多少份,第三个参数表示在第一个维度。(维度默认是0,1,2...)
多维的还没试过
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 22 16:03:45 2018
@author: ZZL
"""
import tensorflow as tf
import numpy as np
a=tf.constant([-1.0,2.0,3.0,4.0])
b=[[[1,2,3] ,[4,5,6],[7,8,9],[10,11,12]] ,[[13,14,15],[16,17,18] ,[19,20,21],[22,23,24]]]
#XT = tf.transpose(b, [1, 0, 2])
bt=tf.transpose(b, [1, 0, 2])
dt=tf.reshape(bt, [-1, 3])
et= tf.split(dt, 4, 0)#X_split =
f=np.argmax(b, axis=1)
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
],
[
[21, 6, -5, 2],
[9, 36, 2, 8],
[3, 7, 79, 1]
]
])
b=np.argmax(a, axis=0)#对于三维度矩阵,a有三个方向a[0][1][2]
#当axis=0时,是在a[0]方向上找最大值,即两个矩阵做比较,具体
#(1)比较3个矩阵的第一行,即拿[1, 5, 5, 2],
# [-1, 7, -5, 2],
# [21, 6, -5, 2],
#再比较每一列的最大值在那个矩阵中,可以看出第一列1,-2,21最大值为21,在第三个矩阵中,索引值为2
#第2列5,7,6最大值为7,在第2个矩阵中,索引值为1.....,最终得出比较结果[2 1 0 0]
#再拿出三个矩阵的第二行,按照上述方法,得出比较结果 [0 2 0 0]
#一共有三个,所以最终得到的结果b就为3行4列矩阵
print(b)
#[[2 1 0 0]
#[0 2 0 0]
#[1 0 2 0]]
with tf.Session() as sess:
m=bt
n=dt
et=et
print(sess.run(m))
print(sess.run(n))
print(sess.run(et))
print(sess.run(et)[1])
显示
runfile('C:/Users/ZZL/lainxi.py', wdir='C:/Users/ZZL')
[[2 1 0 0]
[0 2 0 0]
[1 0 2 0]]
[[[ 1 2 3]
[13 14 15]]
[[ 4 5 6]
[16 17 18]]
[[ 7 8 9]
[19 20 21]]
[[10 11 12]
[22 23 24]]]
[[ 1 2 3]
[13 14 15]
[ 4 5 6]
[16 17 18]
[ 7 8 9]
[19 20 21]
[10 11 12]
[22 23 24]]
[array([[ 1, 2, 3],
[13, 14, 15]]), array([[ 4, 5, 6],
[16, 17, 18]]), array([[ 7, 8, 9],
[19, 20, 21]]), array([[10, 11, 12],
[22, 23, 24]])]
[[ 4 5 6]
[16 17 18]]