nn.Sequential()引起的 forward() takes 1 positional argument but 2 were given

最近在训练模型时,想要将模型的分类层去除,输出模型的特征图,于是进行如下操作去除模型的最后两层结构,然后奇怪的事情就发生了,运行时程序老是报错, forward() takes 1 positional argument but 2 were given!!!

class PvT(nn.Module):
    def __init__(self):  # num_classes,此处为 二分类值为2
        super().__init__()
        #创建模型,并且加载预训练参数
        net= pvt_v2.PyramidVisionTransformerV2()
        #去除分类层
        self.feature = nn.Sequential(*list(net.children())[:-2])

    def forward(self, x):
        x = self.feature(x)
        B, N, C = x.shape
        N = int(N ** 0.5)
        x = rearrange(x, 'b (h w) c -> b c h w', h=N, w=N)
        return x

于是在网上搜寻答案,常见的几个回答有三个:

一. __init__拼写问题

这个就是注意init前后都有两个短的下划线,这个问题好解决,自己对照着改改就好,别整错了

二. 调用对象问题

python调用类时,需要先将类实例化,再给对象传入参数

错误示例:

net=PvT(input)

正确示例:

net=PvT()
net(input)

三. 就是forward需要传入的是一个参数,确实传了两个进来

net=PvT()
net(input1, input2)

如果检查以上三个都没有问题了,那就需要考虑是不是 nn.Sequential()引起的问题

nn.Sequential本质上定义了一个网络,这个网络里面有天然存在的输入输出继承关系。我们可以通过nn.Sequential的源码看到,其自带的forward() 函数不支持传递多个参数。经过查看PvT的源码,发现该模型里面有子模块DWCov,forward里面需要传入多个参数,故此不能使用nn.Sequential,所以会报错。

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

最后直接将PvT源码拿过来用,将最后的head层去除,不使用nn.Sequential,问题得以解决。

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
这个错误是因为在调用Net.forward()时传递了两个参数,但是forward()函数只接受一个参数。根据引用和引用的内容,可能是由于使用了nn.Sequential()或者对模型进行了修改导致的。 在引用中提到,如果使用nn.Sequential()定义了一个网络,它的forward()函数不支持传递多个参数。而在引用中,作者想要去除模型的最后两层结构,然后运行时出现了这个错误。 此外,根据引用中的代码,PvT模型中的子模块DWConv的forward()函数需要传入多个参数,这也可能是造成这个错误的原因。 要解决这个错误,你可以尝试以下几种方法: 1. 检查你是否正确传递了参数到forward()函数中,并确保只传递一个参数。 2. 如果你使用了nn.Sequential(),考虑使用其他方式定义网络,例如使用nn.ModuleList()来手动定义网络层。 3. 如果你对模型进行了修改,确保修改后的模型的forward()函数接受正确的参数。 4. 如果使用了子模块函数,确保子模块的forward()函数正确接受和处理参数。 根据你的具体情况,你可能需要结合上述方法进行调试和修改代码。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [nn.Sequential()引起forward() takes 1 positional argument but 2 were given](https://blog.csdn.net/qq_24193303/article/details/124120415)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值