1. 自定义权重初始化方法
...
def weights_init_normal(m):
    """Initialise one module in place with a normal (DCGAN-style) scheme.

    Meant to be passed to ``model.apply``, which calls it for every
    sub-module and finally for the model itself.

    * Class name contains ``"Conv"`` (Conv*d, ConvTranspose*d, ...):
      weight ~ N(0, 0.02); bias is left untouched.
    * Class name contains ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias = 0.

    Modules that match by name but carry no weight/bias tensor (for example
    a user-defined container called ``ConvBlock``, or
    ``BatchNorm2d(affine=False)`` whose weight/bias are ``None``) are
    skipped instead of raising ``AttributeError``.

    Args:
        m: the ``torch.nn.Module`` currently being visited.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        # Name matching also hits containers like "ConvBlock"; only touch real tensors.
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm2d" in classname:
        # weight/bias are None when the layer was built with affine=False.
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if isinstance(bias, torch.Tensor):
            torch.nn.init.constant_(bias, 0.0)
def weights_init_normal2(m):
    """Initialise one module in place using Xavier (Glorot) normal for convs.

    Meant to be passed to ``model.apply``, which calls it for every
    sub-module and finally for the model itself.

    * Class name contains ``"Conv"``: weight <- ``xavier_normal_`` with
      gain=1; bias is left untouched.
    * Class name contains ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias = 0.

    Modules that match by name but carry no weight/bias tensor (for example
    a user-defined container called ``ConvBlock``, or
    ``BatchNorm2d(affine=False)`` whose weight/bias are ``None``) are
    skipped instead of raising ``AttributeError``.

    Args:
        m: the ``torch.nn.Module`` currently being visited.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        # Name matching also hits containers like "ConvBlock"; only touch real tensors.
        if isinstance(weight, torch.Tensor):
            torch.nn.init.xavier_normal_(weight, gain=1)
    elif "BatchNorm2d" in classname:
        # weight/bias are None when the layer was built with affine=False.
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if isinstance(bias, torch.Tensor):
            torch.nn.init.constant_(bias, 0.0)
def weights_init_normal3(m):
    """Initialise one module in place using Kaiming (He) uniform for convs.

    Meant to be passed to ``model.apply``, which calls it for every
    sub-module and finally for the model itself.

    * Class name contains ``"Conv"``: weight <- ``kaiming_uniform_`` with
      ``mode='fan_in'`` and ``nonlinearity='relu'``; bias is left untouched.
    * Class name contains ``"BatchNorm2d"``: weight ~ N(1, 0.02), bias = 0.

    Modules that match by name but carry no weight/bias tensor (for example
    a user-defined container called ``ConvBlock``, or
    ``BatchNorm2d(affine=False)`` whose weight/bias are ``None``) are
    skipped instead of raising ``AttributeError``.

    Args:
        m: the ``torch.nn.Module`` currently being visited.
    """
    classname = type(m).__name__
    weight = getattr(m, "weight", None)
    if "Conv" in classname:
        # Name matching also hits containers like "ConvBlock"; only touch real tensors.
        if isinstance(weight, torch.Tensor):
            torch.nn.init.kaiming_uniform_(weight, mode='fan_in', nonlinearity='relu')
    elif "BatchNorm2d" in classname:
        # weight/bias are None when the layer was built with affine=False.
        if isinstance(weight, torch.Tensor):
            torch.nn.init.normal_(weight, 1.0, 0.02)
        bias = getattr(m, "bias", None)
        if isinstance(bias, torch.Tensor):
            torch.nn.init.constant_(bias, 0.0)
2. 打印各卷积层的参数
def print_weight(m):
    """Dump the parameters of a ``nn.Conv2d`` module to stdout.

    Intended for ``model.apply``. Any module that is not an ``nn.Conv2d``
    is ignored and produces no output.

    Args:
        m: the ``torch.nn.Module`` currently being visited.
    """
    if not isinstance(m, nn.Conv2d):
        return
    weight_tensor = m.weight.data
    print("weight", weight_tensor)
    print("bias:", m.bias)
    print("next...")
# XXXnet is a placeholder for your own network class.
# Module.apply returns the module itself, so the two passes can be chained:
# first initialise every layer, then print every Conv2d's parameters.
model = XXXnet().apply(weights_init_normal2).apply(print_weight)
注：model.apply(fn) 会递归地将 fn 应用于 model 的每一个子模块（例如 Conv、Linear 层），最后再应用于 model 本身，并返回 model。