模型参数的访问
- 通过
Module
类的parameters()
或者named_parameters()
方法来访问所有参数(以迭代器的形式返回),后者除了返回参数Tensor
外还会返回其名字。 - 对于使用
Sequential
类构造的神经网络,我们可以通过方括号[]
来访问网络的任一层。 param
的类型为torch.nn.parameter.Parameter
,其实这是Tensor
的子类,和Tensor
不同的是如果一个Tensor
是Parameter
,那么它会自动被添加到模型的参数列表里
初始化模型参数
PyTorch中nn.Module
的模块参数都采取了较为合理的初始化策略,PyTorch的init
模块里提供了多种预设的初始化方法。也可以自定义初始化方法
共享模型参数
Module
类的forward
函数里多次调用同一个层。- 如果我们传入
Sequential
的模块是同一个Module
实例的话参数也是共享的
import torch
from torch import nn
from torch.nn import init
class MyModel(nn.Module):
def __init__(self, **kwargs):
super(MyModel, self).__init__(**kwargs)