Tensorflow常用函数记录
1. tf.expand_dims
在一个Tensor的shape中插入一个维度。
tf.expand_dims(
input,
axis=None,
name=None,
dim=None deprecated(等价于axis)
)
使用情况:
对于一个Tensor元素,想给它加一个batch的维度。
例如:有一个图片[height, width, channels]。
通过expand_dims(image, 0)
,将其变成一个图片的batch。
数据变成[1, height, width, channels]
其他例子
# 't2' is a tensor of shape [2, 3, 5]
tf.shape(tf.expand_dims(t2, 0)) # [1, 2, 3, 5]
tf.shape(tf.expand_dims(t2, 2)) # [2, 3, 1, 5]
tf.shape(tf.expand_dims(t2, 3)) # [2, 3, 5, 1]
2. tf.squeeze
删除一个Tensor中所有大小为1的维度。只是在shape上做了改动。
可以指定删除某些大小为1的维度。
tf.squeeze(
input,
axis=None,
name=None,
squeeze_dims=None deprecated(等价于axis)
)
例如:
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t)) # [2, 3]
或者删除某个指定大小为1的维度。
# 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
tf.shape(tf.squeeze(t, [2, 4])) # [1, 2, 3, 1]
3. tf.rank
返回一个tensor的rank(tensor的shape矩阵的大小,类似维度数)
tf.rank(
input,
name=None
)
# shape of tensor 't' is [2, 2, 3]
t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
tf.rank(t) # 3
4. tf.stack
将一个rank-R tensors的列表变成一个rank-(R+1)的tensor。
假设列表大小为N,其中的tensors shape=(A,B,C)
如果axis0, 输出tensor shape=(N,A,B,C)
如果axis1, 输出tensor shape=(A,N,B,C)
tf.stack(
values,
axis=0,
name='stack'
)
返回一个tensor
例如:
x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
tf.stack([x, y, z]) # [[1, 4], [2, 5], [3, 6]] (Pack along first dim.)
tf.stack([x, y, z], axis=1) # [[1, 2, 3], [4, 5, 6]]
5. tf.slice
从tensor input中从begin位置开始提取出一个大小为size的切片。
begin和size都可以是数组。
begin[i]是input的第i个维度所想切片的偏移量。
size[i]是input的第i个维度所想切片的元素数量。
begin从0开始,size从1开始。
size[i] = input.dim_size(i) - begin[i]
tf.slice(
input_,
begin,
size,
name=None
)
返回一个tensor
例如:
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3]) # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3]) # [[[3, 3, 3],
# [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3]) # [[[3, 3, 3]],
# [[5, 5, 5]]]
6. tf.split
将一个tensor切分多个小tensor
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
(1) 如果num_or_size_splits是integer类型
在axis的维度上将value切分成num_splits个小tensor。
(2) 如果num_or_size_splits不是integer类型
在axis的维度上将value切分成size_splits个小tensor。
size_splits是一个list,表示这个axis上的维度如何分配。
# 'value' is a tensor with shape [5, 30]
# Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0) # [5, 4]
tf.shape(split1) # [5, 15]
tf.shape(split2) # [5, 11]
# Split 'value' into 3 tensors along dimension 1
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0) # [5, 10]
7. tf.concat
在一个维度axis上,将tensors组装起来。
tf.concat(
values,
axis,
name='concat'
)
values: list of tensors
将tensors的数据在axis维度上连接。
输入的tensors的维度数量必须匹配,除了axis维度,其他维度必须大小相同。
原文:
That is, the data from the input tensors is joined along the axis dimension.
The number of dimensions of the input tensors must match, and all dimensions except axis must be equal.
例如:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) # [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
# tensor t3 with shape [2, 3]
# tensor t4 with shape [2, 3]
tf.shape(tf.concat([t3, t4], 0)) # [4, 3]
tf.shape(tf.concat([t3, t4], 1)) # [2, 6]