PyTorch VGG16手写数字识别教程

手写数字识别教程:使用PyTorch和VGG16

1. 环境准备

确保你已安装以下库:

pip install torch torchvision
2. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
3. 数据预处理

我们需要对MNIST数据集进行转换,使其适合输入VGG16模型。由于VGG16的输入要求为224x224的图像,因此我们需要调整图像大小,并进行标准化处理。

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 将图像大小调整为224x224
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化处理,均值和标准差
])

# 下载并加载训练和测试数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
4. 定义VGG16模型

VGG16由多个卷积层和全连接层组成。我们将调整输入通道以适应单通道的MNIST数据。

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        # 定义卷积层
        self.vgg = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),  # 将输入通道设置为1(灰度图)
            nn.ReLU(),  # 激活函数
            nn.MaxPool2d(kernel_size=2, stride=2),  # 最大池化层,减小特征图尺寸
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        # 定义全连接层
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),  # 第一个全连接层
            nn.ReLU(),
            nn.Dropout(),  # 随机失活,防止过拟合
            nn.Linear(4096, 4096),  # 第二个全连接层
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 10)  # 输出层,10个类(数字0-9)
        )

    def forward(self, x):
        x = self.vgg(x)  # 通过卷积层
        x = x.view(x.size(0), -1)  # 展平特征图
        x = self.classifier(x)  # 通过全连接层
        return x
5. 训练模型

我们将使用交叉熵损失函数和Adam优化器,并训练模型。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检测可用的设备
model = VGG16().to(device)  # 实例化模型并移动到设备上
criterion = nn.CrossEntropyLoss()  # 损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # 优化器

# 训练循环
for epoch in range(5):  # 训练5个epoch
    model.train()  # 设置为训练模式
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # 移动到设备
        
        optimizer.zero_grad()  # 清空梯度
        outputs = model(images)  # 前向传播
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
    
    print(f'Epoch [{epoch+1}/5], Loss: {loss.item():.4f}')  # 输出当前epoch的损失
6. 测试模型

在测试阶段,我们将计算模型的准确率。

model.eval()  # 设置为评估模式
with torch.no_grad():  # 禁用梯度计算
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)  # 移动到设备
        outputs = model(images)  # 前向传播
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total += labels.size(0)  # 统计总样本数
        correct += (predicted == labels).sum().item()  # 统计正确预测的数量

    print(f'Accuracy: {100 * correct / total:.2f}%')  # 输出准确率

总结

这个教程详细介绍了如何使用VGG16模型对MNIST数据集进行手写数字识别。通过调整网络参数和训练轮数,你可以进一步提高模型的性能。希望这个教程能帮助你更好地理解PyTorch及深度学习的应用!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

吉小雨

你的激励是我创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值