我对nd.concat()函数的理解(记录一下,以备后面查阅)
例如:
a=[[1,2],[3,4]]
b=[[6,7],[8,9]]
a=nd.array(a)
b=nd.array(b)
print(nd.concat(a,b,dim=0))
输出是:
[[1. 2.]
[3. 4.]
[6. 7.]
[8. 9.]]
<NDArray 4x2 @cpu(0)>
一般来讲dim=0,说的是行,dim=1,说的是列。
因为按行拼接的时候把行看成一个整体(例如a矩阵的[1,2],[3,4]是两个整体),a,b矩阵都是2x2矩阵,拼接之后就会有4行。
简单的理解就是把b放到矩阵a的下面(如图),然后形成concat后的矩阵c。
下面把dim改为1
a=[[1,2],[3,4]]
b=[[6,7],[8,9]]
a=nd.array(a)
b=nd.array(b)
nd.concat(a,b,dim=1)
输出:
[[1. 2. 6. 7.]
[3. 4. 8. 9.]]
<NDArray 2x4 @cpu(0)>
按照刚才的理解,肯定有4列
把“列”看成一个整体(例如a矩阵的[1,3],[2,4]是两个整体)
再简单来说就是把b矩阵放到a的右边,形成的矩阵
假如我我改一下数组
a=[[1,2],[3,4]]
b=[[6,7],[8,9],[3,3]]
a=nd.array(a)
b=nd.array(b)
nd.concat(a,b,dim=1)
我先按照咱们的图理解一下(明显的不是我们认识的矩阵)
接着跑一下代码,果然,报错了
感谢看到这里😀