使用 tf.meshgrid() 配合若干 reshape 操作的解决方案：
import tensorflow as tf
import numpy as np
# Solution 1: build every (row, row) index pair with tf.meshgrid, then gather
# and concatenate the corresponding rows of `t`.
# NOTE: this uses the TF1 graph-mode API (tf.placeholder / tf.Session). Under
# TF2 these are only available as tf.compat.v1.* and placeholders additionally
# require tf.compat.v1.disable_eager_execution().
t = tf.placeholder(tf.int32, [None, 2])
num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1]  # actual dynamic dimensions

# Getting pair indices using tf.meshgrid:
# with the default indexing="xy", position [i, j] of the stacked grid holds the
# pair (j, i), so in the flattened result the FIRST element of each pair varies
# fastest (compare with the cartesian-product ordering, which is row-major).
idx_range = tf.range(num_rows)
pair_indices = tf.stack(tf.meshgrid(idx_range, idx_range))  # (2, num_rows, num_rows)
pair_indices = tf.transpose(pair_indices, perm=[1, 2, 0])  # (num_rows, num_rows, 2)

# Finally gathering the rows accordingly:
# tf.gather returns (num_rows, num_rows, 2, size_row); merging the last two axes
# concatenates each pair of rows into one row of length 2 * size_row.
res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))

with tf.Session() as sess:
    # The session call must be indented under `with`; at column 0 the script
    # does not even parse.
    print(sess.run(res, feed_dict={t: np.array([[1, 2], [3, 4], [5, 6]])}))
    # [[1 2 1 2]
    # [3 4 1 2]
    # [5 6 1 2]
    # [1 2 3 4]
    # [3 4 3 4]
    # [5 6 3 4]
    # [1 2 5 6]
    # [3 4 5 6]
    # [5 6 5 6]]
使用行索引笛卡尔积的解决方案：
import tensorflow as tf
import numpy as np
# Solution 2: build every (row, row) index pair as an explicit cartesian
# product of the row indices, then gather and concatenate the matching rows.
# NOTE: this uses the TF1 graph-mode API (tf.placeholder / tf.Session). Under
# TF2 these are only available as tf.compat.v1.* and placeholders additionally
# require tf.compat.v1.disable_eager_execution().
t = tf.placeholder(tf.int32, [None, 2])
num_rows, size_row = tf.shape(t)[0], tf.shape(t)[1]  # actual dynamic dimensions

# Getting pair indices by computing the indices cartesian product:
# row_idx_a[i, j, 0] == i (column vector tiled across columns) and
# row_idx_b[i, j, 0] == j (row vector tiled down rows), each (num_rows, num_rows, 1).
row_idx = tf.range(num_rows)
row_idx_a = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 1), [1, num_rows]), 2)
row_idx_b = tf.expand_dims(tf.tile(tf.expand_dims(row_idx, 0), [num_rows, 1]), 2)
# pair_indices[i, j] == (i, j): pairs come out in row-major order once flattened.
pair_indices = tf.concat([row_idx_a, row_idx_b], axis=2)  # (num_rows, num_rows, 2)

# Finally gathering the rows accordingly:
# tf.gather returns (num_rows, num_rows, 2, size_row); merging the last two axes
# concatenates each pair of rows into one row of length 2 * size_row.
res = tf.reshape(tf.gather(t, pair_indices), (-1, size_row * 2))

with tf.Session() as sess:
    # The session call must be indented under `with`; at column 0 the script
    # does not even parse.
    print(sess.run(res, feed_dict={t: np.array([[1, 2], [3, 4], [5, 6]])}))
    # [[1 2 1 2]
    # [1 2 3 4]
    # [1 2 5 6]
    # [3 4 1 2]
    # [3 4 3 4]
    # [3 4 5 6]
    # [5 6 1 2]
    # [5 6 3 4]
    # [5 6 5 6]]