编程速记(27): Pytorch篇-纠错'Res_rec' object has no attribute '_parameters'——基于nn.Module的网络的搭建

一、报错:‘Res_rec’ object has no attribute ‘_parameters’

在这里插入图片描述

二、纠错前代码

class Res_com(nn.Module):
    def __init__(self, n_com=3,  b_com=6, d_com=2, com_disable=False):
        def make(n, b, d):
            f1 = [2 ** (b + i) for i in range(n)] + [2 ** (b + n - 1 - i) for i in range(n)]
            f2 = [i * (2 ** d) for i in f1]
            del f2[len(f2) // 2]
            f2_last = 32 if f1[0] > 32 else 16
            f2.append(f2_last)
            return f1, f2

        self.f1_com, self.f2_com = make(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable
    def forward(self,x):
        if self.com_disable:
            print("No implementation for com_disable")
            raise
        else:
            print("in_channels for class Res_com is :{}".format(x[1]))
            for i in range(self.n_com * 2):
                x = Res_block(x[1], [self.f1_com[i], self.f1_com[i], self.f2_com[i]])(x)
            out = nn.Conv2d(in_channels=x[1],out_channels=12,kernel_size=1,stride=1,padding=0)(x)
            return out

出错分析

网络中涉及到参数的nn.Conv2d,BatchNorm2d等都被放到了forward,致使程序运行中会认为该网络没有可以更新的权重参数从而出错

解决方案

正确的做法是:将所有涉及到参数更新的层全部以类的变量成员的形式定义在类的初始化方法__init__方法中。

三、纠错后代码

class Res_com(nn.Module):
    def __init__(self, n_com=3,  b_com=6, d_com=2, com_disable=False,in_channels=3):
        def make(n, b, d):
            f1 = [2 ** (b + i) for i in range(n)] + [2 ** (b + n - 1 - i) for i in range(n)]
            f2 = [i * (2 ** d) for i in f1]
            del f2[len(f2) // 2]
            f2_last = 32 if f1[0] > 32 else 16
            f2.append(f2_last)
            return f1, f2

        self.f1_com, self.f2_com = make(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable

        layers = []
        if self.com_disable:
            print("No implementation for com_disable")
            raise
        else:
            for i in range(self.n_com * 2):
                if i == 0 :
                    layers.append(Res_block(in_channels, [self.f1_com[i], self.f1_com[i], self.f2_com[i]]))
                else:
                    layers.append(Res_block(self.f2_com[i-1], [self.f1_com[i], self.f1_com[i], self.f2_com[i]]))
            self.multi_res_block = nn.sequential(*layers)
            self.conv1 = nn.Conv2d(in_channels=x[1],out_channels=12,kernel_size=1,stride=1,padding=0)(x)

    def forward(self,x):
        print("in_channels for class Res_com is :{}".format(x[1]))
        out = self.multi_res_block(x)
        out = self.conv1(out)
        return out
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值