改变Tensor的维度
我想更改Tensor的维度,我发现了3种方法来实现这一点,如下所示
a = tf.constant([[1,2,3],[4,5,6]]) # shape (2,3)
# change dimention of a to (2,3,1)
b = tf.expand_dims(a,2) # shape(2,3,1)
c = a[:,:,tf.newaxis] # shape(2,3,1)
d = tf.reshape(a,(2,3,1)) # shape(2,3,1)
不同点
tf.expand_dims
tf.expand_dims(
input,
axis,
name=None
) #添加一个维度并且它的索引是可变的
t = [[1, 2, 3],[4, 5, 6]] # shape [2, 3]
tf.expand_dims(t, 0) #shape=(1, 2, 3)
tf.expand_dims(t, 1) #shape=(2, 1, 3)
tf.expand_dims(t, 2) #shape=(2, 3, 1)
tf.expand_dims(t, -1) #shape=(2, 3, 1)
上例t的类型为list tf.expand_dims后类型变为tf.Tensor
tf.newaxis
a[:,:,tf.newaxis] #可读性强 添加多个尺寸特别方便(而不是多次调用tf.expand_dims)
为什么可读性强强? 例如
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo.shape)
print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[:, tf.newaxis, :]) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
print(foo[:, :, tf.newaxis]) # =>[[[1],[2],[3]], [[4],[5],[6]],[[7],[8],[9]]]
# 输出的4个的维度依次(3, 3) (1, 3, 3) (3, 1, 3) (3, 3, 1)
# tf.newaxis表示添加的一个维度 :代表原来的维度 再比如
print(foo[:, :, tf.newaxis, tf.newaxis, tf.newaxis])#shape=(3, 3, 1, 1, 1)
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo)
print(foo[tf.newaxis, :, :]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[tf.newaxis, ...]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[tf.newaxis]) # => [[[1,2,3], [4,5,6], [7,8,9]]]
# shape=(3, 3) (1, 3, 3) (1, 3, 3) (1, 3, 3)
# ...代表所以的维度
Tensor操作
去除最前和最后的2个元素
foo = tf.constant([1,2,3,4,5,6])
print(foo[2:-2]) # => [3,4]
跳过每隔一行并反转列的顺序
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[::2,::-1]) # => [[3,2,1], [9,8,7]]
将标量张量用作两个维度的索引
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]]) #shape=(2, 3)
print(foo[tf.constant(0), tf.constant(2)]) # => 3 (第0维的第3个元素)
Masks
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo)
print("*"*50) #分割线
print([foo > 2])
print("*"*50) #分割线
print(foo[foo > 2])
print("*"*50) #分割线
输出
tf.Tensor(
[[1 2 3]
[4 5 6]
[7 8 9]], shape=(3, 3), dtype=int32)
**************************************************
[<tf.Tensor: id=387, shape=(3, 3), dtype=bool, numpy=
array([[False, False, True],
[ True, True, True],
[ True, True, True]])>]
**************************************************
tf.Tensor([3 4 5 6 7 8 9], shape=(7,), dtype=int32)
**************************************************