PyTorch 微调代码完整示例:从模型训练到评估
在深度学习任务中,微调(Fine-tuning)是一个非常重要的步骤,尤其是在迁移学习中。本文将详细介绍如何使用 PyTorch 实现一个完整的微调过程,包括模型定义、数据加载、训练循环、损失计算、反向传播、梯度更新以及模型评估。
1. 环境准备
在开始之前,请确保已经安装 PyTorch。如果尚未安装,可以通过以下命令安装:
pip install torch
2. 代码实现
以下是完整的 PyTorch 微调代码示例,包含模型训练和评估的完整流程。
2.1 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
2.2 自定义数据集
我们首先定义一个简单的自定义数据集类 CustomDataset
,用于加载数据。
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
2.3 定义模型
接下来,我们定义一个简单的线性模型 SimpleModel
。
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1) # 假设输入特征维度为10,输出为1
def forward(self, x):
return self.fc(x)
2.4 初始化模型、损失函数和优化器
我们使用均方误差损失(MSE Loss)和随机梯度下降(SGD)优化器。
model = SimpleModel()
criterion = nn.MSELoss() # 使用均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)
2.5 准备数据
假设我们有一些随机生成的数据和标签。
data = torch.randn(100, 10) # 100个样本,每个样本10个特征
labels = torch.randn(100, 1) # 100个样本,每个样本1个标签
# 创建数据集和数据加载器
dataset = CustomDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
2.6 训练循环
在训练循环中,我们依次完成前向传播、损失计算、反向传播和参数更新。
num_epochs = 10
for epoch in range(num_epochs):
model.train() # 设置模型为训练模式
for batch_data, batch_labels in dataloader:
# 前向传播
outputs = model(batch_data)
loss = criterion(outputs, batch_labels) # 计算损失
# 反向传播和优化
optimizer.zero_grad() # 清零梯度
loss.backward() # 反向传播
optimizer.step() # 更新参数
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
2.7 模型评估
在训练完成后,我们可以使用模型进行预测。注意,在评估模式下需要禁用梯度计算。
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 不计算梯度
test_data = torch.randn(10, 10) # 假设我们有10个测试样本
predictions = model(test_data)
print(predictions)
3. 代码说明
- 模型定义:
SimpleModel
是一个简单的线性模型,适用于输入特征维度为 10 的任务。 - 数据集和数据加载器:
CustomDataset
用于加载自定义数据,DataLoader
用于批量加载数据。 - 训练循环:
model.train()
设置模型为训练模式。optimizer.zero_grad()
清零梯度。loss.backward()
计算梯度。optimizer.step()
更新模型参数。
- 评估模式:
model.eval()
设置模型为评估模式,torch.no_grad()
用于在评估时不计算梯度。
4. 总结
本文提供了一个完整的 PyTorch 微调代码示例,涵盖了模型定义、数据加载、训练循环、损失计算、反向传播、梯度更新以及模型评估的完整流程。通过这个示例,你可以快速上手 PyTorch 并应用于自己的任务中。