机器学习预测-CNN手写字识别

介绍

这段代码是使用PyTorch实现的卷积神经网络(CNN),用于在MNIST数据集上进行图像分类。让我一步步解释:

  1. 导入库:代码导入了必要的库,包括PyTorch(torch)、神经网络模块(torch.nn)、函数模块(torch.nn.functional)、图像数据集(torchvision)以及数据处理(torch.utils.data)和可视化(matplotlib.pyplot)的工具。

  2. 设置超参数:定义了超参数,如批大小(Batch_size)、epoch数量(Epoch)和学习率(Lr)。

  3. 加载MNIST数据集:使用torchvision.datasets.MNIST加载MNIST数据集。该数据集包含了0到9的手写数字的灰度图像。transform=torchvision.transforms.ToTensor()将PIL图像转换为PyTorch张量。

  4. 可视化样本数据:打印数据集的大小,并显示数据集中的第一张图像及其相应的标签。

  5. 准备测试数据:准备测试数据与训练数据类似。加载MNIST测试数据集,并选择前2000个图像进行测试。

  6. 创建数据加载器:使用torch.utils.data.DataLoader创建训练数据的数据加载器。它有助于在训练过程中对数据进行分批和混洗。

  7. 定义CNN架构:通过子类化nn.Module来定义CNN类。该架构包括两个卷积层(self.con1self.con2),后面跟有ReLU激活函数和最大池化层。卷积层的输出被展平并馈入全连接层(self.out),产生最终输出。

  8. 初始化CNN:创建CNN类的实例。

  9. 定义损失函数和优化器:使用交叉熵损失(nn.CrossEntropyLoss)作为损失函数,使用随机梯度下降(torch.optim.SGD)作为优化器。

  10. 训练CNN:在指定的epoch数量循环内训练模型。在循环内,将训练数据通过模型,计算损失,进行梯度反向传播,并由优化器更新模型参数。

  11. 测试模型:每50次迭代训练时,对测试数据集进行评估。将测试预测与真实标签进行比较,计算准确率。

  12. 打印结果:训练结束后,打印模型预测及前10个测试样本的真实标签。

总的来说,这段代码训练了一个CNN模型,用于在MNIST数据集上对手写数字进行分类,并在单独的测试数据集上评估其性能。

代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt

# define hyper parameters
Batch_size = 100
Epoch = 1
Lr = 0.5
#DOWNLOAD_MNIST = True # 若没有数据,用此生成数据

# define train data and test data
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    download=False,
    transform=torchvision.transforms.ToTensor()
)
print(train_data.data.size())
print(train_data.targets.size())
print(train_data.data[0])
# 画一个图片显示出来
plt.imshow(train_data.data[0].numpy(),cmap='gray')
plt.title('%i'%train_data.targets[0])
plt.show()
# print(train_data.data.shape)           # torch.Size([60000, 28, 28])
# print(train_data.targets.size())        # torch.Size([60000])
# print(train_data.data[0].size())        # torch.Size([28, 28])
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.show()
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    # transform=torchvision.transforms.ToTensor()
)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]
test_y = test_data.targets[:2000]
# print(test_x.shape)                         # torch.Size([2000, 1, 28, 28])
# print(test_y.shape)                         # torch.Size([2000])
train_loader = Data.DataLoader(
    dataset=train_data,
    shuffle=True,
    batch_size=Batch_size,
)

# define network structure
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.con1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.con2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.con1(x)            # (batch, 16, 14, 14)
        x = self.con2(x)            # (batch, 32, 7, 7)
        x = x.view(x.size(0), -1)
        out = self.out(x)             # (batch_size, 10)
        return out

cnn = CNN()
# print(cnn)
optimizer = torch.optim.SGD(cnn.parameters(), lr=Lr)
loss_fun = nn.CrossEntropyLoss()

for epoch in range(Epoch):
    for i, (x, y) in enumerate(train_loader):
        output = cnn(x)
        loss = loss_fun(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            test_output = torch.max(cnn(test_x), dim=1)[1]
            loss = loss_fun(cnn(test_x), test_y).item()
            accuracy = torch.sum(torch.eq(test_output, test_y)).item() / test_y.numpy().size
            print('Epoch:', Epoch, '|loss:%.4f' % loss, '|accuracy:%.4f' % accuracy)

print('real value', test_data.targets[: 10].numpy())
print('train value', torch.max(cnn(test_x)[: 10], dim=1)[1].numpy())




结果

real value [7 2 1 0 4 1 4 9 5 9]
train value [7 2 1 0 4 1 4 9 5 9]

 

  • 10
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: 机器学习基于CNN手写识别是一项实验性研究,该技术可以通过训练算法和图像数据集,自动识别分类手写。下面我将用300向您介绍相关实验过程。 实验的第一步是准备数据集,可以使用MNIST数据集,该数据集包含大量手写图像样本。然后,我们将数据集分为训练集和测试集,用于训练和评估模型性能。 接下来,我们使用CNN模型进行手写识别的训练。CNN(卷积神经网络)是一种常用的深度学习模型,特别适用于图像识别。该模型可以自动提取图像中的特征,并进行分类。我们通过不断调整模型的结构和参数,让其能够更好地适应手写识别任务。 训练过程中,我们将训练集的图像输入到CNN模型中,模型通过反向传播算法不断调整权重和偏置,以最小化预测结果与实际标签之间的误差。随着训练的进行,模型逐渐优化,使其在测试集上的准确度得到提升。 完成训练后,我们将使用测试集对模型进行评估。通过与实际标签对比,可以计算出模型的准确率、精确度、召回率等性能指标,从而评估模型的表现。 最后,我们可以使用训练好的模型进行手写识别。将手写图像输入到模型中,模型将自动输出识别结果。 这项实验的目的是将机器学习CNN技术应用于手写识别,提高识别的准确度和速度。它在人工智能、图像处理等领域具有广泛的应用前景,可以为我们提供更多便利和智能化的服务。 ### 回答2: 机器学习是一种可以让计算机通过学习和训练数据来完成特定任务的方法。而基于CNN(卷积神经网络)的手写识别实验即利用机器学习的方法来实现对手写的自动识别。 首先,我们需要准备一个包含大量手写的数据集,这些数据集中既包含手写图片,也包含对应的标签。在该实验中,我们需要将每个手写图片与其对应的数标签建立联系。 接下来,我们可以利用CNN模型来训练和优化识别手写的算法。CNN是一种专门应用于图像处理和识别的深度学习模型。通过分析手写图片中不同的特征和模式,CNN可以学习到一种有效的表示手写的方式。 在训练过程中,我们将数据集划分为训练集和验证集,用于训练和评估模型的性能。通过迭代训练,自动调整模型的参数和权重,使其逐渐提高识别手写的准确率。 完成训练后,我们可以用测试集来评估模型的性能。测试集是一个模型从未见过的数据集,用于模拟实际应用场景。通过与标签比较,我们可以计算出模型在测试集上的准确率,来评判其对手写识别的能力。 最后,我们可以使用训练好的CNN模型来进行实际的手写识别。输入一张手写图片,经过模型的处理和分析,输出对应的数。 综上所述,基于CNN手写识别实验利用机器学习的方法训练和优化模型,以实现自动识别手写的功能。该实验将深度学习和图像处理的技术应用于手写识别,具有较高的准确率和广泛的应用前景。 ### 回答3: 机器学习是一种能够通过训练模型来让计算机对数据进行自动学习的技术。基于卷积神经网络(Convolutional Neural Network,CNN)的手写识别机器学习的一项实验。 首先,为了进行手写识别实验,我们需要准备一个手写的数据集。这个数据集包含了许多手写的图片,每张图片都有对应的标签,表示图片所代表的数。 接着,我们将利用CNN来构建一个模型。CNN是一种深度学习架构,它能够提取图像的特征并用于分类任务。CNN通常由多个卷积层、池化层和全连接层组成。在手写识别实验中,我们可以设计一个具有几个卷积层和全连接层的CNN模型。 然后,我们需要将数据集分为训练集和测试集。训练集用于训练模型,测试集用于评估模型的性能。在训练过程中,模型会根据训练集的数据不断调整自身的参数,以使其能够更好地对手写进行识别。训练的过程中需要定义损失函数和优化器来指导模型的参数更新。 训练完成后,我们将使用测试集来评估模型的性能。评估指标可以是准确率,即模型正确预测手写的比例。较高的准确率表示模型对手写识别能力较强。 此外,为了提高模型的性能,我们还可以采取一些策略,如数据增强、超参数调节等。数据增强可以通过对训练集进行旋转、平移、缩放等操作,生成更多的训练样本,以增加模型的泛化能力。超参数调节可以通过调整模型的学习率、批大小等参数,以找到更好的模型配置。 通过这样的实验,我们可以验证基于CNN手写识别模型的效果,并探索机器学习在图像识别任务中的应用潜力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小张er

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值