看tensorflow源码时发现这个,记录一下:
函数定义: dynamic_partition(data, partitions, num_partition, name=None)
函数功能:在第一个维度上,将data数组,切分成num_partiton个数组,具体哪几个分在一起取决于partition的设置
参数说明: data:要用来做拆分的tensor ;
partitions:指定要将哪几项分在一起,分在一起的项用相同数字表示。如 : data = [ [1 , 2] , [3 , 4] , [5 , 6] , [7 , 8] ] ,partitions = [ 0,1,0,2 ],那么将 partitions[i] = 0的分在一起,partitions[i] = 1的分一起,partitions[i]=2的分一起 ,对应到原数组data,即[1,2]和[5,6]在一起作为一个,[3,4]单独一个,[7,8]单独一个。
num_partition:切分的个数,很明显对于partitions参数,partitions应当在[ 0 , num_partition )之间。如上述例子,值为3.
返回值: 切割好的数组,对应于上面例子,返回值应为: [ [1., 2.] , [ 5., 6.] ] , [ [ 3., 4. ] ] , [ [ 7., 8. ] ]
import tensorflow as tf
x2 = tf.constant([[1 , 2] , [3 , 4] , [5 , 6] , [7 , 8]], tf.float32)
partitions = [0,1,0,2]
result = tf.dynamic_partition(x2, partitions, 3)
with tf.Session() as sess:
r = sess.run(result)
print(r)