nn.Module 是 PyTorch 中所有神经网络的基类

nn.Module 是 PyTorch 中所有神经网络的基类,几乎所有的神经网络模型都需要从 nn.Module 继承。它封装了神经网络的基本结构,包括层的定义、参数管理、前向传播等功能,使得构建、训练和优化深度学习模型变得方便。以下是 nn.Module 的主要功能和作用:

1. 网络层的容器

nn.Module 是一个用于存储和组织网络层的容器。通过继承 nn.Module,你可以将各种网络层(如卷积层、全连接层、批量归一化层等)添加到模型中。它能够自动将这些层的参数注册到模型中,方便后续调用和训练。

例如:

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)  # 定义一个卷积层
        self.fc1 = nn.Linear(16*6*6, 10)  # 定义一个全连接层

2. 前向传播逻辑(forward 方法)

在继承 nn.Module 后,你需要重写 forward 方法,以定义数据如何在模型中流动。这个方法是必须实现的,表示输入如何通过各个网络层并最终得到输出。

例如:

python

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.fc1 = nn.Linear(16*6*6, 10)

    def forward(self, x):
        x = self.conv1(x)   # 数据流过卷积层
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc1(x)     # 数据流过全连接层
        return x            # 输出结果

3. 参数管理

nn.Module 能够自动跟踪所有层中的参数,如权重和偏置。这些参数会被存储在 model.parameters() 中,方便后续用于优化器和梯度更新。

例如:

model = MyNet()
for param in model.parameters():
    print(param)  # 打印模型中的权重和偏置

4. 模型训练与推理模式

nn.Module 提供了 train()eval() 方法,分别用于设置模型为训练模式和推理(评估)模式。这会影响到一些特定层(如 Dropout 和 BatchNorm)的行为。

  • model.train():启用训练模式,Dropout 等层会保留随机性。
  • model.eval():启用推理模式,Dropout 会关闭,BatchNorm 的均值和方差会固定。
model.train() # 训练模式
model.eval() # 推理模式

5. 模块嵌套

nn.Module 还支持模块的嵌套。你可以在一个 nn.Module 中嵌套其他 nn.Module,从而构建复杂的网络。

例如:

class Block(nn.Module):
    def __init__(self):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3)
        
    def forward(self, x):
        return self.conv(x)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.block1 = Block()
        self.block2 = Block()

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        return x

6. 模型参数的保存与加载

通过 nn.Module,你可以方便地保存和加载模型的权重。

  • 保存权重torch.save(model.state_dict(), "model.pth")
  • 加载权重model.load_state_dict(torch.load("model.pth"))

总结

nn.Module 是 PyTorch 框架中构建神经网络的核心基类。它帮助你组织网络结构、管理模型参数、定义前向传播逻辑,并为模型的训练和推理提供便利。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值