###1、concat
tf.concat相当于numpy中的np.concatenate函数,用于将两个张量在某一个维度(axis)合并起来
a=tf.reshape(np.arange(4),(2,2))
[[0 1]
[2 3]]
b=tf.reshape(np.arange(4,8),(2,2))
[[4 5]
[6 7]]
c=tf.concat([a,b],axis=0)
[[0 1]
[2 3]
[4 5]
[6 7]]
d=tf.concat([a,b],axis=1)
[[0 1 4 5]
[2 3 6 7]]
###2、stack
tf.concat拼接的是两个shape完全相同的张量,并且产生的张量的维度不会发生变化,而tf.stack拼接后的张量的维度+1
tf.stack 的axis 值取值范围为$ -[R+1]~(R+1)$
Given a list of length N
of tensors of shape (A, B, C)
;
if axis == 0
then the output
tensor will have the shape (N, A, B, C)
.
if axis == 1
then the output
tensor will have the shape (A, N, B, C)
.
if axis == -1(逆序)
then the output
tensor will have the shape (A, B, C, N)
.
一维tensor
x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
# R=1,axis取值为 0,1 ||| -1,-2 ;axis=1和axis=-1结果相同
tf.stack([x, y, z]) # [[1, 4], [2, 5], [3, 6]]
tf.stack([x, y, z], axis=1) # [[1, 2, 3], [4, 5, 6]]
二维tensor
a=tf.reshape(np.arange(4),(2,2))
b=tf.reshape(np.arange(4,8),(2,2))
# R=2,axis取值为 0,1,2 ||| -1,-2,-3
#axis默认为0
c=tf.stack([a,b])
[[[0 1]
[2 3]]
[[4 5]
[6 7]]]
# axis=1和axis=-2结果相同
d=tf.stack([a,b],axis=1)
[[[0 1]
[4 5]]
[[2 3]
[6 7]]]
# axis=2和axis=-1结果相同
e=tf.stack([a,b],axis=2)
[[[0 4]
[1 5]]
[[2 6]
[3 7]]]
###3、transpose
二维tensor
x = tf.constant([[1, 2, 3], [4, 5, 6]])
tf.transpose(x) # [[1, 4]
# [2, 5]
# [3, 6]]
# Equivalently,将x的0轴和1轴的数据交换
tf.transpose(x, [1, 0]) # [[1, 4]
# [2, 5]
# [3, 6]]
三维tensor
x = tf.constant([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
# 将x的1轴和2轴交换
tf.transpose(x, [0, 2, 1]) # [[[1, 4],
# [2, 5],
# [3, 6]],
#
# [[7, 10],
# [8, 11],
# [9, 12]]]
详细解释可参考:https://blog.csdn.net/u012762410/article/details/78912667
###4、stack和transpose(三维输入)
下面将介绍transpose时的几个常用操作。
我们可以看到transpose高维变换并不直观,而transpose和stack(直观)某些情况下结果一致,从而可以很快得到transpose的结果
x=np.arange(12).reshape((2,2,3))
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
#等价
print(np.stack(x,axis=0))
print(np.transpose(x,(0,1,2)))
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
#等价
print(np.stack(x,axis=1))
print(np.transpose(x,(1,0,2)))
[[[ 0 1 2]
[ 6 7 8]]
[[ 3 4 5]
[ 9 10 11]]]
#等价
print(np.stack(x,axis=2))
print(np.transpose(x,(1,2,0)))
[[[ 0 6]
[ 1 7]
[ 2 8]]
[[ 3 9]
[ 4 10]
[ 5 11]]]