tensorflow常用函数整理
记录自己使用过的tensorflow常用函数,以便以后查询
1 tf.reshape
tf.reshape(
tensor,
shape,
name=None
)
将给定tensor的维度转换为shape。
arr = tf.Variable([[1,2], [3,4], [5,6]])
arr = tf.reshape(arr, [2, 3])
[[1 2 3]
[4 5 6]]
当缺失的维度为-1时,会根据给定的维度自动计算缺失的维度; 但是缺失的维度只能有一个。
arr = tf.reshape(arr, [3, -1])
[[1 2]
[3 4]
[5 6]]
2 tf.concat()
tf.concat(
values,
axis,
name='concat'
)
将输入的张量数据沿着axis维度连接,如果输入数据的维度分别为(2,3), (2,3), axis=0时将第0维的2和2加起来,第1维的两个3不变,连接起来的tensor的shape为(4, 3), 同理axis=1时连接起来的shape为(2,6)。(python中维度的索引从0开始计算,正axis取值范围为[0, rank(values) ,这里也就是[0, 2))
t1 = tf.Variable([[1, 2, 3], [4, 5, 6]])
t2 = tf.Variable([[7, 8, 9], [10, 11, 12]])
t3 = tf.concat([t1, t2], 0)
[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]
t4 = tf.concat([t1, t2], 1)
[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]
在Python中,axis可以为负值,解释为从rank的末尾开始计数,即
axis + rank(values)
t5 = tf.concat([t1, t2], -1)
[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]
3.tf.expand_dims
tf.expand_dims(
input,
axis=None,
name=None,
dim=None
)
这个操作可以用来给单个元素添加batch维度,例如,如果你有一张维度为[height, width, channels]的图片,可以用expand_dims(image, 0)使它成为batch为1的图片,shape将变为
[1, height, width, channels]。
t1 = tf.Variable([[1, 2, 3], [4, 5, 6]])
t_expand = tf.expand_dims(t1, 0)
[[[1 2 3]
[4 5 6]]]
shape:(1, 2, 3)