一、访问模型参数:
import torch
from torch import nn
from torch.nn import init
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1)) # pytorch已进行默认初始化
上一节说了,Sequential类继承自Module类,对于Sequential实例中含模型参数的层,我们可以通过Module类的parameters()或者named_parameters方法来访问所有参数。
比如对于上面的net:
for name, param in net.named_parameters():
print(name, param.size(), type(param))
输出:
0.weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>
0.bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>
2.weight torch.Size([1, 3]) <class 'torch.nn.parameter.Parameter'>
2.bias torch.Size([1]) <class 'torch.nn.parameter.Parameter'>
我们再用named_parameters函数试试:
for name, param in net.named_parameters():
print(name, param.size(), type(param))
输出:
0.weight torch.Size([3, 4]) <class 'torch.nn.parameter.Parameter'>
0.bias torch.Size([3]) <class 'torch.nn.parameter.Parameter'>
2.weight torch.Size([1, 3]) <class 'torch.nn.parameter.Parameter'>
2.bias torch.Size([1]) <class 'torch.nn.parameter