"[...]"、"[::]"与torch.cat的解析(YOLOv5 common.py)
In YOLOv5's common.py, Focus.forward returns self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)). This post breaks that statement down piece by piece.
def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
    return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
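As a quick check of the shape comment above, here is a minimal sketch of the slicing and concatenation step alone, without the convolution. The helper name focus_slices and the random input size are assumptions made for illustration; they are not part of YOLOv5.

```python
import torch

def focus_slices(x: torch.Tensor) -> torch.Tensor:
    """Slicing + concatenation part of Focus.forward (hypothetical helper, no conv)."""
    return torch.cat(
        [
            x[..., ::2, ::2],    # even rows, even columns
            x[..., 1::2, ::2],   # odd rows, even columns
            x[..., ::2, 1::2],   # even rows, odd columns
            x[..., 1::2, 1::2],  # odd rows, odd columns
        ],
        dim=1,  # channel axis of a 4-D (b, c, h, w) tensor
    )

x = torch.randn(2, 3, 8, 8)  # assumed input: b=2, c=3, 8x8 spatial size
y = focus_slices(x)
print(tuple(y.shape))  # (2, 12, 4, 4)
```

The channel count is multiplied by 4 and each spatial dimension is halved, which matches the comment on forward.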
1. Basic use of "[...]"
import numpy as np
A = np.arange(25).reshape(5,5)
array([[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]])
(1) Keep every row; 2 selects the column at index 2
A[...,2]
array([ 2, 7, 12, 17, 22])
(2) Keep every row; :2 selects the columns with index < 2, i.e. columns 0 and 1
A[...,:2]
array([[ 0, 1],
[ 5, 6],
[10, 11],
[15, 16],
[20, 21]])
(3) Keep every row; ::2 uses a step of 2, i.e. columns 0, 2 and 4
A[...,::2]
array([[ 0, 2, 4],
[ 5, 7, 9],
[10, 12, 14],
[15, 17, 19],
[20, 22, 24]])
(4) ::2 takes every second row (rows 0, 2 and 4); every column is kept
A[::2,...]
array([[ 0, 1, 2, 3, 4],
[10, 11, 12, 13, 14],
[20, 21, 22, 23, 24]])
(5) None inserts a new leading axis, equivalent to np.reshape(A, [1,5,5])
A[None,...]
array([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]]])
Summary: "..." stands for all the axes that are not indexed explicitly, each taken in full; "::x" means a step of x.
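The summary can be checked directly. The following is a small sketch showing that "..." gives the same result as writing out the full slices by hand; the 3-D array T is an extra example introduced here and does not appear in the original post.

```python
import numpy as np

A = np.arange(25).reshape(5, 5)

# On a 2-D array, "..." stands for the single axis that is not indexed explicitly.
same_column = np.array_equal(A[..., 2], A[:, 2])
same_rows = np.array_equal(A[::2, ...], A[::2, :])

# On a 3-D array, "..." expands to two full slices.
T = np.arange(24).reshape(2, 3, 4)
same_last_axis = np.array_equal(T[..., ::2], T[:, :, ::2])

# None (an alias of np.newaxis) adds a leading axis of length 1.
B = A[None, ...]

print(same_column, same_rows, same_last_axis, B.shape)  # True True True (1, 5, 5)
```

Because "..." adapts to the number of dimensions, the same indexing expression works for both 3-D and 4-D inputs, which is why Focus.forward uses it.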
2. torch.cat
import torch
A = np.arange(9).reshape(3,3)
B = np.arange(1,10).reshape(3,3)
AA = torch.from_numpy(A)
BB = torch.from_numpy(B)
(1) Concatenate along dimension 0, i.e. stack vertically (more rows)
C1 = torch.cat([AA,BB],0)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
(2) Concatenate along dimension 1, i.e. stack horizontally (more columns)
C2 = torch.cat([AA,BB],1)
tensor([[0, 1, 2, 1, 2, 3],
[3, 4, 5, 4, 5, 6],
[6, 7, 8, 7, 8, 9]], dtype=torch.int32)
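A short sketch of how the dim argument affects the result shape, and of the constraint that the inputs may differ only along the concatenation axis. The tensors are built with torch.arange instead of torch.from_numpy, and the tensor D is an extra example introduced here.

```python
import torch

AA = torch.arange(9).reshape(3, 3)
BB = torch.arange(1, 10).reshape(3, 3)

C1 = torch.cat([AA, BB], dim=0)  # rows are appended: 3 + 3 rows
C2 = torch.cat([AA, BB], dim=1)  # columns are appended: 3 + 3 columns
print(tuple(C1.shape), tuple(C2.shape))  # (6, 3) (3, 6)

# D has 2 rows and 3 columns, so it matches AA only in the column count.
D = torch.zeros(2, 3, dtype=AA.dtype)
E = torch.cat([AA, D], dim=0)  # fine: sizes differ only along dim 0
print(tuple(E.shape))  # (5, 3)

mismatch_raised = False
try:
    torch.cat([AA, D], dim=1)  # row counts differ (3 vs 2), so this fails
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```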
3. torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
X = np.arange(16*3).reshape(3,4,4)
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]]])
(1) Step of 2 along both rows and columns: rows 0, 2 and columns 0, 2
X[...,::2,::2]
array([[[ 0, 2],
[ 8, 10]],
[[16, 18],
[24, 26]],
[[32, 34],
[40, 42]]])
(2) Step of 2 along both rows and columns, with rows starting at index 1: rows 1, 3 and columns 0, 2
X[..., 1::2, ::2]
array([[[ 4, 6],
[12, 14]],
[[20, 22],
[28, 30]],
[[36, 38],
[44, 46]]])
(3) Step of 2 along both rows and columns, with columns starting at index 1: rows 0, 2 and columns 1, 3
X[..., ::2, 1::2]
array([[[ 1, 3],
[ 9, 11]],
[[17, 19],
[25, 27]],
[[33, 35],
[41, 43]]])
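The fourth slice, X[..., 1::2, 1::2], follows the same pattern with both rows and columns starting at index 1. The sketch below prints it and also checks that the four slices together cover every element of X exactly once; the variable names parts and all_values are introduced here for illustration.

```python
import numpy as np

X = np.arange(16 * 3).reshape(3, 4, 4)

# Rows 1, 3 and columns 1, 3 of each 4x4 block.
fourth = X[..., 1::2, 1::2]
print(fourth[0])  # first block: [[ 5  7], [13 15]]

parts = [X[..., ::2, ::2], X[..., 1::2, ::2], X[..., ::2, 1::2], fourth]

# Flatten all four slices and sort: the result is 0..47 with no gaps or repeats.
all_values = np.sort(np.concatenate([p.ravel() for p in parts]))
print(np.array_equal(all_values, np.arange(48)))  # True
```

So the four slices rearrange the data without discarding any of it: each keeps one pixel out of every 2x2 neighbourhood.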
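(4)torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
Note that X here has shape (3,4,4) with no batch axis, so dimension 1 is the row axis and the four slices are stacked vertically inside each block. In Focus.forward the input has shape (b,c,h,w), where dimension 1 is the channel axis.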
X = torch.from_numpy(X)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]]], dtype=torch.int32)
torch.cat([X[..., ::2, ::2], X[..., 1::2, ::2], X[..., ::2, 1::2], X[..., 1::2, 1::2]], 1)
tensor([[[ 0, 2],
[ 8, 10],
[ 4, 6],
[12, 14],
[ 1, 3],
[ 9, 11],
[ 5, 7],
[13, 15]],
[[16, 18],
[24, 26],
[20, 22],
[28, 30],
[17, 19],
[25, 27],
[21, 23],
[29, 31]],
[[32, 34],
[40, 42],
[36, 38],
[44, 46],
[33, 35],
[41, 43],
[37, 39],
[45, 47]]], dtype=torch.int32)
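The result above has shape (3, 8, 2) because the example tensor has no batch axis. To see what Focus.forward itself produces, here is a sketch that repeats the concatenation with a leading batch axis added through X[None]; the helper four_slices is a hypothetical name used only for this comparison.

```python
import torch

X = torch.arange(16 * 3).reshape(3, 4, 4)  # (c, h, w), no batch axis

def four_slices(t):
    """Return the four strided slices used in Focus.forward (illustrative helper)."""
    return [t[..., ::2, ::2], t[..., 1::2, ::2], t[..., ::2, 1::2], t[..., 1::2, 1::2]]

no_batch = torch.cat(four_slices(X), 1)          # dim 1 is the row axis here
with_batch = torch.cat(four_slices(X[None]), 1)  # dim 1 is the channel axis here

print(tuple(no_batch.shape))    # (3, 8, 2)
print(tuple(with_batch.shape))  # (1, 12, 2, 2)

# With a batch axis, channels 0-2 hold the first slice, channels 3-5 the second, and so on.
print(torch.equal(with_batch[0, :3], X[..., ::2, ::2]))    # True
print(torch.equal(with_batch[0, 3:6], X[..., 1::2, ::2]))  # True
```

With the batch axis present, the channel count goes from 3 to 12 while height and width are halved, which is the x(b,c,w,h) -> y(b,4c,w/2,h/2) mapping stated in the comment on forward.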