最近在训练模型时,想要将模型的分类层去除,输出模型的特征图,于是进行如下操作去除模型的最后两层结构,然后奇怪的事情就发生了,运行时程序老是报错, forward() takes 1 positional argument but 2 were given!!!
class PvT(nn.Module):
def __init__(self): # num_classes,此处为 二分类值为2
super().__init__()
#创建模型,并且加载预训练参数
net= pvt_v2.PyramidVisionTransformerV2()
#去除分类层
self.feature = nn.Sequential(*list(net.children())[:-2])
def forward(self, x):
x = self.feature(x)
B, N, C = x.shape
N = int(N ** 0.5)
x = rearrange(x, 'b (h w) c -> b c h w', h=N, w=N)
return x
于是在网上搜寻答案,常见的几个回答有三个:
一. __init__拼写问题
这个就是注意init前后都有两个短的下划线,这个问题好解决,自己对照着改改就好,别整错了
二. 调用对象问题
python调用类时,需要先将类实例化,再给对象传入参数
错误示例:
net=PvT(input)
正确示例:
net=PvT()
net(input)
三. 就是forward需要传入的是一个参数,确实传了两个进来
net=PvT()
net(input1, input2)
如果检查以上三个都没有问题了,那就需要考虑是不是 nn.Sequential()引起的问题
nn.Sequential本质上定义了一个网络,这个网络里面有天然存在的输入输出继承关系。我们可以通过nn.Sequential的源码看到,其自带的forward() 函数不支持传递多个参数。经过查看PvT的源码,发现该模型里面有子模块DWCov,forward里面需要传入多个参数,故此不能使用nn.Sequential,所以会报错。
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
最后直接将PvT源码拿过来用,将最后的head层去除,不使用nn.Sequential,问题得以解决。