nn.Sequential本质上是新定义了一个网络,这个网络里面有天然存在的输入输出继承关系。我们可以通过nn.Sequential的源码看到,其自带的forward() 函数不支持传递多个参数。经过查看我们所构建的网络的源码,发现该模型里面有子模块DWCov,forward里面需要传入多个参数,故此不能使用nn.Sequential,所以会报错。
DWCov代码如下:
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x