通常使用pytorch在一个类的__init__()函数中声明完网络结构后,pytorch会自动初始化待训练的网络结构的权值。但这种初始化过程是随机的,参数分布没有规律且相差较大,使得网络收敛速度下降。因此,我们手动初始化权重,可以采用服从正态分布的数据来初始化权重。
1. 方法一:先定义网络,后初始化权重
def weights_init_normal(m): # 初始化权重
classname = m.__class__.__name__
if classname.find("Conv") != -1:
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm2d") != -1:
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
class ConvBlockInit(nn.Module): # 定义网络结构
def __init__(self, in_channels, out_channels):
super(ConvBlockInit, self).__init__()
self.init_conv = nn.Sequential(OrderedDict([
("conv0", nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)),
("batchnorm0", nn.BatchNorm2d(out_channels)),
("relu0", nn.ReLU(inplace=True))
]))
def forward(self, x):
return self.init_conv(x)
net = ConvBlockInit(64, 256)
net.apply(weight_init) # 加载权重
'''
报错如下:
torch.nn.modules.module.ModuleAttributeError: 'ConvBlockInit' object has no attribute 'weight'
'''
上述初始化网络权重的方法会产生错误torch.nn.modules.module.ModuleAttributeError: ‘ConvBlockInit’ object has no attribute ‘weight’。
报错原因: apply()函数会递归的对该网络结构的所有children结构应用权重初始化条件,同时也对该网络结构应用初始化参数,然而ConvBlockInit不具有weight这个属性,所以报错。
我也把ConvBlockInit对应的m.class.__name__输出了一下,如下:
所以,不推荐上述方法,可以参照下边的参数初始化方法。
def weights_init_normal(m): # 初始化权重
if isinstance(m, nn.Conv2d):
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
2. 方法二:在声明网络结构时,就初始化权重
class ConvBlockInit(nn.Module): # 定义网络结构
def __init__(self, in_channels, out_channels):
super(ConvBlockInit, self).__init__()
self.init_conv = nn.Sequential(OrderedDict([
("conv0", nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=1, padding=3, bias=False)),
("batchnorm0", nn.BatchNorm2d(out_channels)),
("relu0", nn.ReLU(inplace=True))
]))
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
torch.nn.init.constant_(m.bias.data, 0.0)
def forward(self, x):
return self.init_conv(x)
注意:如果在声明网络结构时,就初始化权重,那么就不需要再使用net.apply()了