第一步修改是将head部分里面的Conv替换成了GSConv,GSConv函数是基于Conv函数修改的,主要的修改内容如下:
1.在model/common.py里添加GSConv的函数代码,添加的代码段如下:
#---------------自己添加的GSConv------------
class GSConv(nn.Module):
def __init__(self,c1,c2,k=1,s=1,g=1, act=True):
c_ = c2 // 2
self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
def forward(self,x):
x1 = self.cv1(x)
x2 = torch.cat((x1,self.cv2(x1)),1)
# shuffle
b, n, h, w = x2.data.size()
b_n = b * n // 2
y = x2.reshape(b_n, 2, h * w)
y = y.permute(1, 0, 2)
y = y.reshape(2, -1, n // 2, h, w)