如何用 PyTorch 构建和训练一个基本的神经网络
在开始之前,我们先了解一下 PyTorch 的几个关键概念:(初学者不需要深入,知道有这些东西即可)
- 张量 (Tensor): 张量是 PyTorch 中的基本数据结构,类似于 NumPy 的 ndarray,但它们也可以在 GPU 上进行运算。
- 自动微分 (Autograd): PyTorch 提供了强大的自动微分功能,允许你通过简单的前向传播自动计算梯度,这是训练神经网络的核心。
- 神经网络 (nn.Module): PyTorch 的 torch.nn 模块提供了构建神经网络的基础模块,包括各种层(例如全连接层、卷积层)和激活函数。
- 优化器 (Optimizer): torch.optim 模块包含了多种优化算法(如 SGD、Adam),用于更新模型参数以最小化损失函数。
- 数据加载 (DataLoader): torch.utils.data.DataLoader 是一个工具,帮助你将数据集分批(batch)加载到模型中,尤其是当数据集很大时。
首先导入需要用的的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
友情提醒:如果你使用的不同的优化器和损失函数,会对参数的类型有不同的要求。如果你使用了和例子中不同的损失函数等发生了类型或者其他错误,请上网查找解决方案。一般都是由于参数类型不符合,将 tensor 的类型进行转换即可。
1. 定义数据
创建一些随机数据作为输入和标签
X = torch.randn(100, 3) # 100个样本,每个样本有3个特征
y = torch.randn(100, 1) # 100个样本的目标值
将数据集转换为 TensorDataset,并使用 DataLoader 加载
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
2. 定义模型
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(3, 5) # 全连接层,输入3个特征,输出5个特征
self.fc2 = nn.Linear(5, 1) # 全连接层,输入5个特征,输出1个特征
def forward(self, x):
x = torch.relu(self.fc1(x)) # ReLU 激活函数
x = self.fc2(x)
return x
model = SimpleNN()
3. 定义损失函数和优化器
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01) # 随机梯度下降优化器
4. 训练模型
for epoch in range(100): # 训练100个epoch
for batch_X, batch_y in dataloader:
optimizer.zero_grad() # 梯度清零
output = model(batch_X) # 前向传播
loss = criterion(output, batch_y) # 计算损失
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新模型参数
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
5. 测试模型(示例)
with torch.no_grad(): # 在测试时不需要计算梯度
test_X = torch.tensor([[0.5, -0.2, 0.1]]) # 输入一个测试样本
test_output = model(test_X)
print(f'Test Output: {test_output.item()}')
输出:
Epoch 1, Loss: 0.7683534622192383
Epoch 2, Loss: 0.4891282021999359
Epoch 3, Loss: 0.5800896883010864
Epoch 4, Loss: 1.342455267906189
Epoch 5, Loss: 0.5247911810874939
Epoch 6, Loss: 0.8952728509902954
Epoch 7, Loss: 0.829267144203186
Epoch 8, Loss: 1.3714759349822998
Epoch 9, Loss: 0.8819470405578613
Epoch 10, Loss: 0.8902752995491028
Epoch 11, Loss: 1.0910319089889526
Epoch 12, Loss: 1.5810422897338867
Epoch 13, Loss: 1.2618199586868286
Epoch 14, Loss: 0.7619099617004395
Epoch 15, Loss: 0.7575093507766724
Epoch 16, Loss: 1.0389002561569214
Epoch 17, Loss: 0.7607104182243347
Epoch 18, Loss: 1.107161521911621
Epoch 19, Loss: 1.039939284324646
Epoch 20, Loss: 0.9904875755310059
Epoch 21, Loss: 1.8228839635849
Epoch 22, Loss: 1.0611368417739868
Epoch 23, Loss: 1.0013045072555542
Epoch 24, Loss: 1.3398150205612183
Epoch 25, Loss: 1.2530044317245483
...
Epoch 98, Loss: 1.3919637203216553
Epoch 99, Loss: 0.7074191570281982
Epoch 100, Loss: 0.9921405911445618
Test Output: -0.11975270509719849
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...