# Manual, in-place initialization of existing parameter tensors.
# (torch.nn.init also provides ready-made helper functions for these.)
torch.nn.init
weight.data.fill_(1)               # constant init: all ones (e.g. a BatchNorm scale)
bias.data.fill_(0)                 # constant init: all zeros
weight.data.uniform_(-stdv, stdv)  # uniform init in [-stdv, stdv]
1.
# All learnable parameter tensors of the network, in registration order.
params = list(net.parameters())
2.
# Grab the parameters of one specific sub-module (here: net.conv2).
# A Conv2d registers its weight first, then its bias.
conv2params = list(net.conv2.parameters())
kernels = conv2params[0]  # the convolution kernels (weight tensor)
bias = conv2params[1]     # the bias vector
3.
# He/Kaiming-style manual init for a conv net; this loop runs inside the
# model's own method (e.g. __init__), so `self` is the network itself.
for m in self.modules():
    if isinstance(m, nn.Conv2d):
        # Fan-out of the conv: kernel_h * kernel_w * out_channels.
        # std = sqrt(2 / n) is the He init recommended for ReLU networks.
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm starts as the identity transform: scale 1, shift 0.
        m.weight.data.fill_(1)
        m.bias.data.zero_()
4.
def weights_init(m):
    """DCGAN-style initializer, meant to be passed to ``net.apply``.

    Conv* layers get weights ~ N(0, 0.02); BatchNorm* layers get
    weights ~ N(1, 0.02) and zero bias. Any other module (e.g. Linear)
    is left untouched.
    """
    # Dispatch on the class name so every Conv/BatchNorm variant
    # (1d/2d/3d, transposed) is matched by a single substring test.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
5.
def weight_init(m):
    """Xavier/Glorot initializer for Linear layers, for use with ``net.apply``.

    Weights are drawn from N(0, std) with std = sqrt(2 / (fan_in + fan_out)).
    Non-Linear modules are left untouched.
    """
    if isinstance(m, nn.Linear):
        size = m.weight.size()
        fan_out = size[0]  # number of rows = output features
        fan_in = size[1]   # number of columns = input features
        # NOTE: this is a standard deviation, not a variance —
        # normal_() expects (mean, std), so passing the sqrt is correct.
        std = np.sqrt(2.0 / (fan_in + fan_out))
        m.weight.data.normal_(0.0, std)
net = Residual()       # instantiate the network from its class
# Recursively visit every sub-module and run the initializer on it.
# NOTE(review): the original called `weights_init`, but the function defined
# in this snippet is `weight_init` (the Linear-layer initializer the
# surrounding text describes) — use the matching name.
net.apply(weight_init)
The apply
function will search recursively for all the modules inside your network, and will call the function on each of them. So all Linear
layers you have in your model will be initialized using this one call.
6.
- If you want to load a model's
state_dict
into another model (for example to fine-tune a pre-trained network), load_state_dict
was strict on matching the key names of the parameters. Now we provide a strict=False
option to load_state_dict
where it only loads in parameters where the keys match, and ignores the other parameter keys.
---------------------------------------------------reference--------------------------------
1. https://discuss.pytorch.org/t/weight-initilzation/157/2