PyTorch入门:以MNIST手写数字识别为例

PyTorch入门:以MNIST手写数字识别为例

引言

在人工智能和深度学习的世界里,PyTorch已经成为最受欢迎的框架之一。它不仅强大灵活,而且易于学习和使用。今天,我们将通过一个经典的机器学习问题——MNIST手写数字识别,来入门PyTorch。无论你是深度学习新手,还是希望迁移到PyTorch的开发者,这篇文章都将为你提供一个良好的起点。

PyTorch简介

PyTorch是由Facebook的AI研究团队开发的开源机器学习库。它以其动态计算图和直观的Python式语法而闻名。与其他框架相比,PyTorch的学习曲线相对平缓,这使得它成为初学者的理想选择。

MNIST数据集

MNIST是机器学习领域最著名的数据集之一。它包含70,000张28x28像素的手写数字图像,其中60,000张用于训练,10,000张用于测试。我们的目标是训练一个模型,能够准确识别这些手写数字。

开始使用PyTorch

1. 安装PyTorch

首先,我们需要安装PyTorch。可以使用pip命令进行安装:

pip install torch torchvision

2. 导入必要的库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

3. 加载MNIST数据集

PyTorch提供了方便的工具来加载MNIST数据集:

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

4. 定义神经网络模型

我们将创建一个简单的全连接神经网络:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()

5. 定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

6. 训练模型

for epoch in range(5):  # 训练5个epoch
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

7. 测试模型

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'准确率: {100 * correct / total}%')

结论

通过这个简单的MNIST手写数字识别例子,我们完成了PyTorch的基本入门。我们学习了如何加载数据、定义模型、设置损失函数和优化器、训练模型并进行测试。这只是PyTorch强大功能的冰山一角,但它为进一步探索深度学习和PyTorch奠定了基础。

随着你对PyTorch的深入学习,你会发现它在处理更复杂的任务时也同样得心应手。无论是计算机视觉、自然语言处理还是强化学习,PyTorch都能为你提供所需的工具和灵活性。

希望这篇入门指南能激发你对PyTorch和深度学习的兴趣。记住,实践是掌握这些技能的关键。继续探索,不断尝试新的项目,你会发现PyTorch的魅力所在。祝你在机器学习的旅程中取得成功!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

逸巽散人

不请我喝个咖啡吗?(笑)

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值