PyTorch 中的 `nn.init` 模块专为参数初始化而设计,实现了常用的初始化策略。
import torch as t
from torch import nn
from torch.nn import init
# Seed the global RNG *before* building the layer: nn.Linear draws its default
# weight and bias when it is constructed, and only the weight is overwritten
# below, so seeding afterwards would leave the printed bias unreproducible.
t.manual_seed(1)
linear = nn.Linear(3, 4)

# xavier_normal_ refills the weight in place with samples from N(0, std^2),
# where std = gain * sqrt(2 / (fan_in + fan_out)) and gain defaults to 1.
# This is roughly linear.weight.data.normal_(0, std) with std chosen from the
# layer's fan-in/fan-out. The un-suffixed init.xavier_normal is a deprecated
# alias for this in-place variant.
init.xavier_normal_(linear.weight)

# Print every parameter (weight, then bias) together with its name.
for name, paras in linear.named_parameters():
    print(name, paras)
OUT:
weight Parameter containing:
tensor([[ 0.3535, 0.1427, 0.0330],
[ 0.3321, -0.2416, -0.0888],
[-0.8140, 0.2040, -0.5493],
[-0.3010, -0.4769, -0.0311]])
bias Parameter containing:
tensor([-0.3732, 0.3750, 0.3505, 0.5120])
D:/pycharm/PyTorch/data.py:16: UserWarning: nn.init.xavier_normal is now deprecated in favor of nn.init.xavier_normal_.
init.xavier_normal(linear.weight)
对模型的所有参数进行初始化:`named_parameters()` 每次迭代返回一个 (参数名, 参数张量) 对,权重和偏置分别在不同的迭代中出现,因此可以根据参数名来选择相应的初始化方法:
# Initialize every parameter of `net`, choosing the initializer from the
# parameter's qualified name. Names are built from attribute names, so a
# submodule stored as `self.linear` yields "linear.weight" and "linear.bias".
# NOTE(review): this matching assumes `net` names its submodules with
# "linear" / "conv" / "norm" -- confirm against the model definition.
#
# named_parameters() yields ONE (name, tensor) pair per parameter, so the
# weight and the bias arrive in separate iterations. They must be told apart
# by name; indexing a single tensor (paras[0], paras[1]) would only return
# its first/second row.
for name, paras in net.named_parameters():
    if 'linear' in name:
        if name.endswith('weight'):
            init.xavier_normal_(paras)  # in-place, no autograd tracking
        elif name.endswith('bias'):
            init.zeros_(paras)
    elif 'conv' in name:
        pass  # TODO: choose an initializer for convolution parameters
    elif 'norm' in name:
        pass  # TODO: choose an initializer for normalization parameters