The code is as follows:
net = nn.Linear(2, 2)
print(net.state_dict())
Output:
OrderedDict([
('weight', tensor([[ 0.6155, -0.4649],
[-0.1848, -0.0663]])),
('bias', tensor([-0.0265, -0.4134]))])
From the output we can see that nn.Linear automatically assigns initial values to weight and bias. Stepping into the Linear source, its docstring reads:
Attributes:
    weight: the learnable weights of the module of shape
        :math:`(\text{out\_features}, \text{in\_features})`. The values are
        initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
        :math:`k = \frac{1}{\text{in\_features}}`
    bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
        If :attr:`bias` is ``True``, the values are initialized from
        :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
        :math:`k = \frac{1}{\text{in\_features}}`
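The bound above can be checked directly: every freshly initialized value should fall inside :math:`(-\sqrt{k}, \sqrt{k})` with :math:`k = 1/\text{in\_features}`. A minimal sketch, assuming PyTorch is installed (the layer sizes here are arbitrary):

```python
import math
import torch
from torch import nn

in_features = 8
layer = nn.Linear(in_features, 4)

# k = 1 / in_features, so all values lie within +/- sqrt(k)
bound = math.sqrt(1.0 / in_features)

assert layer.weight.abs().max().item() <= bound
assert layer.bias.abs().max().item() <= bound
```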
In other words, nn.Linear performs initialization automatically.
A custom initialization can be written as follows:
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
# nn.Flatten flattens each multi-dimensional sample into one row so the linear layer can operate on it
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)
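net.apply walks every submodule (the Sequential container itself, nn.Flatten, and nn.Linear) and calls the function on each; only the nn.Linear branch fires. A small sketch confirming the re-initialization took effect, assuming PyTorch is installed (the std check is just a loose sanity bound):

```python
import torch
from torch import nn

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    # apply() visits nn.Flatten, nn.Linear, and the Sequential itself
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

# With 784*10 values, the empirical std should be close to 0.01
print(net[1].weight.std().item())
```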
Other ways to access parameters:
net = nn.Linear(2, 2)
print('net.state_dict()\n', net.state_dict())
print('net.bias\n', net.bias)
print('net.bias.data\n', net.bias.data)
print('net.bias.grad\n', net.bias.grad)  # None because backward() has not been called yet
print('net.named_parameters()\n', net.named_parameters())
print([(name, param) for name, param in net.named_parameters()])
Output:
net.state_dict()
OrderedDict([('weight', tensor([[-0.6907, -0.4128],
[-0.4212, 0.0620]])), ('bias', tensor([-0.2019, -0.3731]))])
net.bias
Parameter containing:
tensor([-0.2019, -0.3731], requires_grad=True)
net.bias.data
tensor([-0.2019, -0.3731])
net.bias.grad
None
net.named_parameters()
<generator object Module.named_parameters at 0x151b22ba0>
[('weight', Parameter containing:
tensor([[-0.6907, -0.4128],
[-0.4212, 0.0620]], requires_grad=True)), ('bias', Parameter containing:
tensor([-0.2019, -0.3731], requires_grad=True))]
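As the output shows, .grad stays None until backward() runs. A minimal sketch, assuming PyTorch is installed (the batch of random inputs is a hypothetical toy example), showing the gradient getting populated:

```python
import torch
from torch import nn

net = nn.Linear(2, 2)
assert net.bias.grad is None  # no backward pass yet

x = torch.randn(4, 2)         # hypothetical toy batch of 4 samples
loss = net(x).sum()
loss.backward()

# Now the gradient exists; for a sum loss, d(loss)/d(bias_j) = batch size
assert net.bias.grad is not None
print(net.bias.grad)
```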