nn.Module
是 PyTorch 中所有神经网络的基类,几乎所有的神经网络模型都需要从 nn.Module
继承。它封装了神经网络的基本结构,包括层的定义、参数管理、前向传播等功能,使得构建、训练和优化深度学习模型变得方便。以下是 nn.Module
的主要功能和作用:
1. 网络层的容器
nn.Module
是一个用于存储和组织网络层的容器。通过继承 nn.Module
,你可以将各种网络层(如卷积层、全连接层、批量归一化层等)添加到模型中。它能够自动将这些层的参数注册到模型中,方便后续调用和训练。
例如:
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3) # 定义一个卷积层
self.fc1 = nn.Linear(16*6*6, 10) # 定义一个全连接层
2. 前向传播逻辑(forward
方法)
在继承 nn.Module
后,你需要重写 forward
方法,以定义数据如何在模型中流动。这个方法是必须实现的,表示输入如何通过各个网络层并最终得到输出。
例如:
python
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3)
self.fc1 = nn.Linear(16*6*6, 10)
def forward(self, x):
x = self.conv1(x) # 数据流过卷积层
x = x.view(x.size(0), -1) # 展平
x = self.fc1(x) # 数据流过全连接层
return x # 输出结果
3. 参数管理
nn.Module
能够自动跟踪所有层中的参数,如权重和偏置。这些参数会被存储在 model.parameters()
中,方便后续用于优化器和梯度更新。
例如:
model = MyNet()
for param in model.parameters():
print(param) # 打印模型中的权重和偏置
4. 模型训练与推理模式
nn.Module
提供了 train()
和 eval()
方法,分别用于设置模型为训练模式和推理(评估)模式。这会影响到一些特定层(如 Dropout 和 BatchNorm)的行为。
model.train()
:启用训练模式,Dropout 等层会保留随机性。model.eval()
:启用推理模式,Dropout 会关闭,BatchNorm 的均值和方差会固定。
model.train() # 训练模式
model.eval() # 推理模式
5. 模块嵌套
nn.Module
还支持模块的嵌套。你可以在一个 nn.Module
中嵌套其他 nn.Module
,从而构建复杂的网络。
例如:
class Block(nn.Module):
def __init__(self):
super(Block, self).__init__()
self.conv = nn.Conv2d(3, 16, 3)
def forward(self, x):
return self.conv(x)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.block1 = Block()
self.block2 = Block()
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
return x
6. 模型参数的保存与加载
通过 nn.Module
,你可以方便地保存和加载模型的权重。
- 保存权重:
torch.save(model.state_dict(), "model.pth")
- 加载权重:
model.load_state_dict(torch.load("model.pth"))
总结
nn.Module
是 PyTorch 框架中构建神经网络的核心基类。它帮助你组织网络结构、管理模型参数、定义前向传播逻辑,并为模型的训练和推理提供便利。