tensorflow中tf.concat的axis的使用我一直理解的比较模糊,这次做个笔记理下自己的思路。
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np
先生成两个矩阵m1, 和m2, 大小为两行三列
m1 = np.random.rand(2,3) # m1.shape (2,3)
m1
>>array([[0.44529968, 0.42451167, 0.07463199],
[0.35787143, 0.22926186, 0.34583839]])
m2 = np.random.rand(2,3) # m2.shape (2,3)
m2
>>array([[0.92811531, 0.6180391 , 0.71969461],
[0.00564108, 0.55381637, 0.17155987]])
接下来采用tf.concat进行连接,简单来说,axis=0实际就是按行拼接,axis=1就是按列拼接
# axis = 0
m3 = tf.concat([m1,m2],axis=0)
m3
>> array([[0.44529968, 0.42451167, 0.07463199],
[0.35787143, 0.22926186, 0.34583839],
[0.92811531, 0.6180391 , 0.71969461],
[0.00564108, 0.55381637, 0.17155987]])
m3.shape
>> (4,3