张量的变换
功能 代码 改变视图 tf.reshape(x,shape) 增加维度 tf.expand_dims(x, axis) 删除维度 tf.squeeze(x, axis) 交换维度 tf.transpose(x, perm) 复制数据 tf.tile(x, multiples)
改变视图
tf.reshape(x,shape)形成新形状 视图变换只需要满足新视图的元素总量与存储区域大小相等即可 x.ndim可以查询维度,x.shape查询形状
import tensorflow as tf
x = tf. range ( 96 )
y = tf. reshape( x, [ 2 , 4 , 4 , 3 ] )
print ( x, '\n' , y)
out:
tf. Tensor(
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 ] , shape= ( 96 , ) , dtype= int32)
tf. Tensor(
[ [ [ [ 0 1 2 ]
[ 3 4 5 ]
[ 6 7 8 ]
[ 9 10 11 ] ]
[ [ 12 13 14 ]
[ 15 16 17 ]
[ 18 19 20 ]
[ 21 22 23 ] ]
[ [ 24 25 26 ]
[ 27 28 29 ]
[ 30 31 32 ]
[ 33 34 35 ] ]
[ [ 36 37 38 ]
[ 39 40 41 ]
[ 42 43 44 ]
[ 45 46 47 ] ] ]
[ [ [ 48 49 50 ]
[ 51 52 53 ]
[ 54 55 56 ]
[ 57 58 59 ] ]
[ [ 60 61 62 ]
[ 63 64 65 ]
[ 66 67 68 ]
[ 69 70 71 ] ]
[ [ 72 73 74 ]
[ 75 76 77 ]
[ 78 79 80 ]
[ 81 82 83 ] ]
[ [ 84 85 86 ]
[ 87 88 89 ]
[ 90 91 92 ]
[ 93 94 95 ] ] ] ] , shape= ( 2 , 4 , 4 , 3 ) , dtype= int32)
import tensorflow as tf
x = tf. range ( 96 )
y = tf. reshape( x, [ 2 , 4 , 4 , - 1 ] )
print ( x, '\n' , y)
out:
tf. Tensor(
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 ] , shape= ( 96 , ) , dtype= int32)
tf. Tensor(
[ [ [ [ 0 1 2 ]
[ 3 4 5 ]
[ 6 7 8 ]
[ 9 10 11 ] ]
[ [ 12 13 14 ]
[ 15 16 17 ]
[ 18 19 20 ]
[ 21 22 23 ] ]
[ [ 24 25 26 ]
[ 27 28 29 ]
[ 30 31 32 ]
[ 33 34 35 ] ]
[ [ 36 37 38 ]
[ 39 40 41 ]
[ 42 43 44 ]
[ 45 46 47 ] ] ]
[ [ [ 48 49 50 ]
[ 51 52 53 ]
[ 54 55 56 ]
[ 57 58 59 ] ]
[ [ 60 61 62 ]
[ 63 64 65 ]
[ 66 67 68 ]
[ 69 70 71 ] ]
[ [ 72 73 74 ]
[ 75 76 77 ]
[ 78 79 80 ]
[ 81 82 83 ] ]
[ [ 84 85 86 ]
[ 87 88 89 ]
[ 90 91 92 ]
[ 93 94 95 ] ] ] ] , shape= ( 2 , 4 , 4 , 3 ) , dtype= int32)
import tensorflow as tf
x = tf. range ( 96 )
y = tf. reshape( x, [ 2 , 4 , 4 , - 1 ] )
print ( y. ndim)
print ( y. shape)
out:
4
( 2 , 4 , 4 , 3 )
增、 删维度
tf.expand_dims(x, axis)可以插入一个指定的新维度(类似指针) axis 可为正负 tf.squeeze(x, axis)函数, axis 参数为待删除的维度的索引号,缺省时删除所有为1的维度
import tensorflow as tf
x = tf. range ( 96 )
y = tf. reshape( x, [ 2 , 4 , 4 , 3 ] )
r = tf. expand_dims( y, axis = 4 )
r1 = tf. expand_dims( y, axis= - 4 )
print ( r. shape)
print ( r1. shape)
out:
( 2 , 4 , 4 , 3 , 1 )
( 2 , 1 , 4 , 4 , 3 )
import tensorflow as tf
x = tf. range ( 96 )
y = tf. reshape( x, [ 2 , 4 , 4 , 3 ] )
r = tf. expand_dims( y, axis = 4 )
r1 = tf. expand_dims( y, axis= - 4 )
print ( r. shape)
print ( r1. shape)
r2 = tf. squeeze( r, axis = 4 )
r3 = tf. squeeze( r1)
print ( r2. shape)
print ( r3. shape)
out:
( 2 , 4 , 4 , 3 , 1 )
( 2 , 1 , 4 , 4 , 3 )
( 2 , 4 , 4 , 3 )
( 2 , 4 , 4 , 3 )
交换维度
tf.transpose(x, perm) ,perm为列表形式的新维数 以[b, h,w, c]转换到[b, c, h,w]为例:tf.transpose(x, perm),perm=[0,3,1,2]
import tensorflow as tf
x = tf. range ( 27 )
y = tf. reshape( x, [ 3 , 3 , 3 ] )
y1 = tf. transpose( y, [ 0 , 2 , 1 ] )
print ( y, '\n' , y1)
out:
tf. Tensor(
[ [ [ 0 1 2 ]
[ 3 4 5 ]
[ 6 7 8 ] ]
[ [ 9 10 11 ]
[ 12 13 14 ]
[ 15 16 17 ] ]
[ [ 18 19 20 ]
[ 21 22 23 ]
[ 24 25 26 ] ] ] , shape= ( 3 , 3 , 3 ) , dtype= int32)
tf. Tensor(
[ [ [ 0 3 6 ]
[ 1 4 7 ]
[ 2 5 8 ] ]
[ [ 9 12 15 ]
[ 10 13 16 ]
[ 11 14 17 ] ]
[ [ 18 21 24 ]
[ 19 22 25 ]
[ 20 23 26 ] ] ] , shape= ( 3 , 3 , 3 ) , dtype= int32)
复制数据
tf.tile(x, multiples) 以x作为单位基准方块,multiples指定单位基准方块的拓展,multiples = [2,2]时如下
import tensorflow as tf
a = tf. range ( 9 )
x = tf. reshape( a, [ 3 , 3 ] )
b = tf. tile( x, multiples= [ 2 , 2 ] )
b1 = tf. tile( x, multiples= [ 2 , 1 ] )
print ( x, '\n' , b, '\n' , b1)
out:
tf. Tensor(
[ [ 0 1 2 ]
[ 3 4 5 ]
[ 6 7 8 ] ] , shape= ( 3 , 3 ) , dtype= int32)
tf. Tensor(
[ [ 0 1 2 0 1 2 ]
[ 3 4 5 3 4 5 ]
[ 6 7 8 6 7 8 ]
[ 0 1 2 0 1 2 ]
[ 3 4 5 3 4 5 ]
[ 6 7 8 6 7 8 ] ] , shape= ( 6 , 6 ) , dtype= int32)
tf. Tensor(
[ [ 0 1 2 ]
[ 3 4 5 ]
[ 6 7 8 ]
[ 0 1 2 ]
[ 3 4 5 ]
[ 6 7 8 ] ] , shape= ( 6 , 3 ) , dtype= int32)
Broadcasting
tf.broadcast_to(x, new_shape) Broadcasting和 tf.tile 复制的最终效果是一样的,操作对用户透明,但是 Broadcasting 机制节省了大量计算资源,建议在运算过程中尽可能地利用 Broadcasting 机制提高计算效率