tf.expand_dims() 和 tf.tile()
tf.expand_dims()：在 axis 指定的位置插入一个长度为 1 的新维度，只修改形状，数据不变
tf.tile()：按 multiples 中给出的倍数沿各维度复制数据，数据会被重复
import tensorflow as tf
# tf.expand_dims(): insert a new axis of length 1 at position `axis`.
# Only the shape changes; the element values are left untouched.
x1 = tf.constant(value=[1, 2, 3], dtype=tf.float32)  # shape (3,)
x2 = tf.constant(value=[
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
    [10, 11, 12]
], dtype=tf.float32)  # shape (4, 3)

y1 = tf.expand_dims(x1, axis=0)   # (3,)   -> (1, 3)
y21 = tf.expand_dims(x2, axis=0)  # (4, 3) -> (1, 4, 3)
y22 = tf.expand_dims(x2, axis=1)  # (4, 3) -> (4, 1, 3)
y23 = tf.expand_dims(x2, axis=2)  # (4, 3) -> (4, 3, 1)

# NOTE(review): tf.Session is the TensorFlow 1.x graph-mode API; under
# TF 2.x this would need tf.compat.v1 with eager execution disabled --
# confirm the target TF version.
with tf.Session() as sess:
    # Evaluate and print each tensor; the statements must be indented
    # inside the `with` body so they run while the session is open.
    print(sess.run(x1))
    print(sess.run(y1))
    print(sess.run(x2))
    print(sess.run(y21))
    print(sess.run(y22))
    print(sess.run(y23))
'''
[1. 2. 3.]
[[1. 2. 3.]]
[[ 1. 2. 3.]
[ 4. 5. 6.]
[ 7. 8. 9.]
[10. 11. 12.]]
[[[ 1. 2. 3.]
[ 4. 5. 6.]
[ 7. 8. 9.]
[10. 11. 12.]]]
[[[ 1. 2. 3.]]
[[ 4. 5. 6.]]
[[ 7. 8. 9.]]
[[10. 11. 12.]]]
[[[ 1.]
[ 2.]
[ 3.]]
[[ 4.]
[ 5.]
[ 6.]]
[[ 7.]
[ 8.]
[ 9.]]
[[10.]
[11.]
[12.]]]
'''
# tf.tile(): repeat a tensor's data. multiples[i] is the number of copies
# made along axis i, so the output shape is input_shape * multiples
# (elementwise).
x3 = tf.constant(value=[
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
    [10, 11, 12]
], dtype=tf.float32)  # shape (4, 3)

# Add a length-1 middle axis first so there is an axis to tile along.
y31 = tf.expand_dims(x3, axis=1)  # (4, 3) -> (4, 1, 3)
y32 = tf.tile(y31, [1, 2, 1])     # repeat along axis 1: (4, 1, 3) -> (4, 2, 3)
y33 = tf.tile(y31, [1, 1, 2])     # repeat along axis 2: (4, 1, 3) -> (4, 1, 6)

# NOTE(review): tf.Session is the TensorFlow 1.x graph-mode API; under
# TF 2.x this would need tf.compat.v1 with eager execution disabled --
# confirm the target TF version.
with tf.Session() as sess:
    # Evaluate and print each tensor inside the open session.
    print(sess.run(x3))
    print(sess.run(y31))
    print(sess.run(y32))
    print(sess.run(y33))
'''
[[ 1. 2. 3.]
[ 4. 5. 6.]
[ 7. 8. 9.]
[10. 11. 12.]]
[[[ 1. 2. 3.]]
[[ 4. 5. 6.]]
[[ 7. 8. 9.]]
[[10. 11. 12.]]]
[[[ 1. 2. 3.]
[ 1. 2. 3.]]
[[ 4. 5. 6.]
[ 4. 5. 6.]]
[[ 7. 8. 9.]
[ 7. 8. 9.]]
[[10. 11. 12.]
[10. 11. 12.]]]
[[[ 1. 2. 3. 1. 2. 3.]]
[[ 4. 5. 6. 4. 5. 6.]]
[[ 7. 8. 9. 7. 8. 9.]]
[[10. 11. 12. 10. 11. 12.]]]
'''