tf.concat()
作用:沿着某一维度合并张量
举例:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
# 当axis = 0时
print(tf.concat([t1, t2], 0))
# [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
# 当axis = 1时
print(tf.concat([t1, t2], 1))
# [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
在举个例子:
# tensor t3 with shape [2, 3]
# tensor t4 with shape [2, 3]
tf.shape(tf.concat([t3, t4], 0))
# [4, 3]
tf.shape(tf.concat([t3, t4], 1))
# [2, 6]
可以看出来,按照axis = 0 时得到的张量shape就是将第一个位置上的两个2相加,后者不变。axis = 1 时得到的张量的shape就是将第二个位置上的两个3相加,前者不变。
更加普遍的有:
V a l u e s [ i ] . s h a p e = [ D 0 , D 1 , . . . D a x i s ( i ) , . . . D n ] Values[i].shape = [D_0, D_1, ... D_{axis(i)}, ...D_n] Values[i].shape=[D0,D1,...Daxis(i),...Dn]
得到的结果的
s h a p e = [ D 0 , D 1 , . . . R a x i s , . . . D n ] shape = [D_0, D_1, ... R_{axis}, ...D_n] shape=[D0,D1,...Raxis,...Dn]
其中
R a x i s = s u m ( D a x i s ( i ) ) R_{axis} = sum(D_{axis(i)}) Raxis=sum(Daxis(i))