【Pytorch实战(五)】实现MNIST手写体识别

一、MNIST数据集
二、实现MNIST手写体识别
1.借助torchvision下载数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=torchvision.transforms.ToTensor(), download=True) # 下载训练集
test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                          transform=torchvision.transforms.ToTensor(), download=True) # 下载测试集

其中transform=torchvision.transforms.ToTensor()将数据转换为张量

2.借助DataLoader加载数据集
batch_size = 100  # 每批加载100条数据

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)  # 不打乱顺序
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
3.训练模型并保存权重文件
fc = torch.nn.Linear(28 * 28, 10) # 数据集图片大小为28*28;10对应0-9这种结果
L = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(fc.parameters(), lr=0.001)

num_epoch = 5
for epoch in range(num_epoch):
    for idx, (images, labels) in enumerate(train_loader):
        x = images.reshape(-1, 28*28)

        optimizer.zero_grad()
        preds = fc(x)
        loss = L(preds, labels)
        loss.backward()
        optimizer.step()

        if idx % 100 == 0:
            print('{} epochs, loss={}'.format(epoch, loss))

torch.save(fc.state_dict(), './mnist_model.pth')
  • 这里并未使用CNN网络,只借助torch.nn.Linear构造简单的模型
4.测试准确率
correct = 0
total = 0
fc.load_state_dict(torch.load('./mnist_model.pth'))
for images, labels in test_loader:
    x = images.reshape(-1, 28*28)
    preds = fc(x)
    predicted = torch.argmax(preds, 1)  # 最大值索引正好对应于0-9预测值
    total += labels.size(0)
    correct += (predicted==labels).sum().item()

accuracy = correct / total
print('correct={},total={},accuracy={:.1%}'.format(correct, total, accuracy))
因为这里的模型比较简单,最终准确率约为92.2%。
  • 若需借助CNN则可参照如下代码搭建卷积神经网络:
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv0 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(stride=2, kernel_size=2)
        self.fc5 = nn.Linear(128 * 14 * 14, 1024)
        self.relu6 = nn.ReLU()
        self.dropout7 = nn.Dropout(p=0.5)
        self.fc8 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.conv0(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu3(x)
        x = self.pool4(x)
        x = x.view(-1, 128 * 14 * 14)
        x = self.fc5(x)
        x = self.relu6(x)
        x = self.dropout7(x)
        x = self.fc8(x)

        return x
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值