昨天通过画图理解了tf.reduce系列函数,今天正好又碰到了tf.concat()函数,跟昨天思路一样,又画了张图来直观理解。
首先,tf.concat()函数是用来拼接两个矩阵的,参数包括:values, axis, name='concat' ,values就是要拼接的两个矩阵,axis就是维度,一会上图解释,name就是指令名称。那么怎样拼接依靠的还是axis这个参数。同tf.reduce 类似,根据矩阵维度的不同,axis取值不一。
下面开始:
如果所示,最左边是待拼接的两个三维矩阵,上面为a,下图为b,形状都是(4, 2, 3)。
当axis=0时,拼接后的形状为(8, 2, 3),结果如右上图:相当于原形状的后两个维度不变,第一个维度相加,体现出来的就是两个矩阵的最外层括号内的逗号分割的元素进行拼接(相当于昨日两个红中括号内的绿中括号们叠拼在一起);
当axis=1时,拼接后的形状为(4, 4, 3),结果如右中图:相当于原形状的第一和第三个维度不变,中间维度相加,体现出来的就是中间层括号内的两个矩阵元素进行拼接(相当于昨日两个矩阵对应的每一个绿色中括号的元素拼在一起 );
当axis=2时,拼接后的形状为(4, 2, 6)