1. 使用 PyTorch 提供的接口初始化权重
方法 1：直接调用 torch.nn.init 中的函数对参数做初始化，该模块提供了多种初始化分布（如 Xavier、Kaiming、均匀分布、正态分布、常数等）。
# Method 1: call torch.nn.init functions directly on a layer's parameters.
# Inside an nn.Module's __init__ the layer would be stored as self.conv1;
# a plain local name is used here so the snippet runs on its own.
import torch.nn as nn
import torch.nn.init as init

conv1 = nn.Conv2d(3, 20, 5, stride=1, bias=True)
# The trailing-underscore functions are the current in-place API; the old
# xavier_uniform / constant names are deprecated aliases.
# calculate_gain("relu") returns sqrt(2), the same gain the snippet previously
# hard-coded via np.sqrt(2.0) (numpy was never imported).
init.xavier_uniform_(conv1.weight, gain=init.calculate_gain("relu"))
init.constant_(conv1.bias, 0.1)
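方法 2：先定义一个初始化函数，再通过 nn.Module.apply 把它递归地应用到网络的每个子模块（以及网络本身），参见 https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.apply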
# Method 2: nn.Module.apply(fn) calls fn on every submodule and then on the
# module itself, so a single function can initialise a whole network.
import torch
import torch.nn as nn


def init_weights(m):
    """Print *m* and, if it is an ``nn.Linear``, fill its weight with 1.0.

    Meant to be passed to ``nn.Module.apply``; modules of any other type are
    printed but left unchanged.

    Args:
        m: the module currently being visited by ``apply``.
    """
    print(m)
    if isinstance(m, nn.Linear):
        # Write to the parameter under no_grad rather than through ``.data``:
        # ``.data`` hides the in-place modification from autograd, whereas
        # no_grad is the supported way to mutate a leaf parameter in place.
        with torch.no_grad():
            m.weight.fill_(1.0)
        print(m.weight)


net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
PyTorch 源码 torch/nn/modules/conv.py 中定义了卷积层默认的参数初始化函数 reset_parameters，早期版本的实现如下：
def reset_parameters(self):
    """Re-draw ``self.weight`` (and ``self.bias`` if present) from U(-stdv, stdv).

    ``stdv = 1 / sqrt(in_channels * prod(kernel_size))``, i.e. the inverse
    square root of ``self.in_channels`` multiplied by every entry of
    ``self.kernel_size``. The bias, when not ``None``, uses the same bound.
    """
    import math

    import torch

    fan_in = self.in_channels
    for k in self.kernel_size:
        fan_in *= k
    # NOTE(review): for grouped convolutions the true per-filter fan-in would be
    # in_channels / groups; this formula uses in_channels as-is.
    stdv = 1. / math.sqrt(fan_in)
    # Mutate the parameters in place under no_grad instead of via ``.data``,
    # so autograd is not bypassed; the sampled values are identical.
    with torch.no_grad():
        self.weight.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.uniform_(-stdv, stdv)