`torch.nn` 是 PyTorch 库中的一个模块,专门用于构建和训练神经网络。
`torch.nn` 模块的关键组件:
1. **神经网络层**:
`torch.nn` 提供了各种神经网络层,如全连接层(`nn.Linear`)、卷积层(`nn.Conv2d`)、循环层(`nn.LSTM`)等。这些层可以方便地堆叠起来,构建复杂的神经网络模型。
2. **损失函数**:
`torch.nn` 包含了多种常用的损失函数,如均方误差损失(`nn.MSELoss`)、交叉熵损失(`nn.CrossEntropyLoss`)等,用于评估模型的性能。
3. **激活函数**:
`torch.nn` 提供了多种激活函数,如 ReLU(`nn.ReLU`)、Sigmoid(`nn.Sigmoid`)等,用于引入非线性特性到模型中。
4. **容器类**:
`torch.nn` 提供了一些容器类,如 `nn.Sequential` 和 `nn.ModuleList`,用于将多个层组合在一起,形成一个更大的模型。
5. **自定义模块**:
通过继承 `nn.Module` 类,开发者可以定义自己的神经网络模块,重写 `__init__` 和 `forward` 方法,灵活地设计模型的结构和前向传播过程。
以下是一个简单的示例,展示如何使用 `torch.nn` 模块定义一个神经网络模型:
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(10, 50) # 输入层到隐藏层
self.fc2 = nn.Linear(50, 1) # 隐藏层到输出层
self.relu = nn.ReLU() # ReLU 激活函数
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
# 创建模型实例
model = SimpleNN()