API: https://tensorflow.google.cn/api_docs/python/tf/expand_dims?hl=zh-cn
tf.expand_dims(
input,
axis=None,
name=None,
dim=None
)
在input的axis位置插入一维的张量
这个操作在input的维度中索引为axis
的位置插入一维张量。维度索引axis
从零开始; 如果指定负数,axis
则从末尾向后计数
例子:
# 't' is a tensor of shape [2]
tf.shape(tf.expand_dims(t, 0)) # [1, 2]
tf.shape(tf.expand_dims(t, 1)) # [2, 1]
tf.shape(tf.expand_dims(t, -1)) # [2, 1]
# 't2' is a tensor of shape [2, 3, 5]
tf.shape(tf.expand_dims(t2, 0)) # [1, 2, 3, 5]
tf.shape(tf.expand_dims(t2, 2)) # [2, 3, 1, 5]
tf.shape(tf.expand_dims(t2, 3)) # [2, 3, 5, 1]
这个操作在一个batch里面插入一个一维元素是很好用的