Pytorch学习(二)Pytorch对声明的网络结构进行初始化及‘ConvBlockInit‘ object has no attribute ‘weight‘错误出现原因分析

通常使用pytorch在一个类的__init__()函数中声明完网络结构后,pytorch会自动初始化待训练的网络结构的权值。但这种初始化过程是随机的,参数分布没有规律且相差较大,使得网络收敛速度下降。因此,我们手动初始化权重,可以采用服从正态分布的数据来初始化权重。

1. 方法一:先定义网络,后初始化权重
def weights_init_normal(m):     # 初始化权重
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

class ConvBlockInit(nn.Module):    # 定义网络结构
    def __init__(self, in_channels, out_channels):
        super(ConvBlockInit, self).__init__()
        self.init_conv = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)),
            ("batchnorm0", nn.BatchNorm2d(out_channels)),
            ("relu0", nn.ReLU(inplace=True))
        ]))

    def forward(self, x):
        return self.init_conv(x)

net = ConvBlockInit(64, 256) 
net.apply(weight_init)  # 加载权重

'''
报错如下:
torch.nn.modules.module.ModuleAttributeError: 'ConvBlockInit' object has no attribute 'weight'
'''

上述初始化网络权重的方法会产生错误torch.nn.modules.module.ModuleAttributeError: ‘ConvBlockInit’ object has no attribute ‘weight’
报错原因: apply()函数会递归的对该网络结构的所有children结构应用权重初始化条件,同时也对该网络结构应用初始化参数,然而ConvBlockInit不具有weight这个属性,所以报错。
我也把ConvBlockInit对应的m.class.__name__输出了一下,如下:
在这里插入图片描述
所以,不推荐上述方法,可以参照下边的参数初始化方法。

def weights_init_normal(m):     # 初始化权重
     if isinstance(m, nn.Conv2d):
         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
     elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
         torch.nn.init.constant_(m.bias.data, 0.0)
2. 方法二:在声明网络结构时,就初始化权重
class ConvBlockInit(nn.Module):    # 定义网络结构
    def __init__(self, in_channels, out_channels):
        super(ConvBlockInit, self).__init__()
        self.init_conv = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)),
            ("batchnorm0", nn.BatchNorm2d(out_channels)),
            ("relu0", nn.ReLU(inplace=True))
        ]))
	    for m in self.modules():
			if isinstance(m, nn.Conv2d):
				torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
			elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
				torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
				torch.nn.init.constant_(m.bias.data, 0.0)
    def forward(self, x):
        return self.init_conv(x)

注意:如果在声明网络结构时,就初始化权重,那么就不需要再使用net.apply()了

  • 8
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值