- 初始化方法在 torch.nn.init 中
- 实例化模型后,在模型init函数中初始化权重
初始化函数:
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_normal_(m.weight.data)
#nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
if m.bias is not None:
m.bias.data.zero_()
#nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
#nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight.data, 0, 0.01)
# m.weight.data.normal_(0,0.01)
m.bias.data.zero_()
先从self.modules()中遍历每一层,判断各层属于什么类型,是否是Conv2d,是否是BatchNorm2d,是否是Linear的,然后根据不同类型的层,设定不同的权值初始化方法,例如Xavier,kaiming,normal_等等。