tf.tile()函数是用来对张量(Tensor)进行扩展,作用是对当前张量内的数据进行一定规则的复制。
函数定义为:
tf.tile(
input,
multiples,
name=None
)
input是待扩展的张量,multiples是扩展方法。
假如input是一个3维(比如:(5, 6, 7)。)的张量。那么mutiples就必须是一个1*3的1维张量。这个张量的三个值依次表示input的第1,第2,第3维数据扩展几倍。
例子:
import tensorflow as tf
a = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.float32)
a1 = tf.tile(a, [2, 3])
with tf.Session() as sess:
print(sess.run(a))
print(sess.run(a1))
结果:
# (3, 2),为2维
[[1. 2.]
[3. 4.]
[5. 6.]]
# (6, 6),仍为2维
[[1. 2. 1. 2. 1. 2.]
[3. 4. 3. 4. 3. 4.]
[5. 6. 5. 6. 5. 6.]
[1. 2. 1. 2. 1. 2.]
[3. 4. 3. 4. 3. 4.]
[5. 6. 5. 6. 5. 6.]]
tf.tile()具体的操作过程如下:
请注意:上面绘图中第一次扩展后第一维由三个数据变成两行六个数据,多一行并不是多了一维,数据扔为顺序排列,只是为了方便绘制而已。
tf.tile保持张量原维度不变,改变每个维度的元素个数;tf.expand_dims改变张量的维度(详见:https://blog.csdn.net/duanlianvip/article/details/96448393)。
每一维数据的扩展都是将前面的数据进行复制然后直接接在原数据后面。
如果multiples的某一个数据为1,则表示该维数据保持不变。