莫凡pytorch mnist手写数据识别

mnist手写数据识别 CNNC训练

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

#hyper parameters超参数
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False

# 下载MNIST
train_data = torchvision.datasets.MNIST(
    root='./mnist',         #保存在文件夹中
    train=True,         #false则是测试点test
    transform=torchvision.transforms.ToTensor(),     #tensor数据的值在0-1之间 ,彩的0-255压缩到黑白的0-1
    download=DOWNLOAD_MNIST
    )
# #plot one example
# print(train_data.train_data.size())          #60000,28,28
# print(train_data.train_labels.size())      #60000
# plt.imshow(train_data.train_data[0].numpy(),cmap='gray')
# plt.title('%i'%train_data.train_labels[0])
# plt.show()                                           ctrl+/ duohangzhushi

train_loader = Data.DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True,num_workers=2)

test_data = torchvision.datasets.MNIST(root='./mnist/',train=False)
test_x = Variable(torch.unsqueeze(test_data.test_data,dim=1),volatile=True).type(torch.FloatTensor)[:2000]/255    #shape
test_y = test_data.test_labels[:2000]   #取前1000个标签y

#构建网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(                          #卷积层          (28,28,1)
                in_channels=1,               #图片的高度,黑白1层
                out_channels=16,          #输出的高度,过滤器的个数
                kernel_size=5,               #过滤器是5*5*16
                stride=1,                  #步长 1
                padding=2                  #填充 0   如果步长是1,则padding=(kernal_size-1)/2, 为了保证与前面图片大小相等
            ),                                                 # (28,28,16)
            nn.ReLU(),           #jihuohanshu
            nn.MaxPool2d(kernel_size=2),        #池化层    2*2  高度不变   (14,14,16)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16,32,5,1,2),               #加工成32层      (14,14,32)
            nn.ReLU(),
            nn.MaxPool2d(2),                                       #(7,7,32)
        )
        self.out = nn.Linear(32*7*7,10)        #将三维的数据展评成二维的数据

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)       #batch,327,7
        x = x.view(x.size(0),-1)      # batch,32*7*7
        output = self.out(x)          #输出10
        return output
cnn = CNN()
#print(cnn)

#训练网络
optimizer = torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func = nn.CrossEntropyLoss()

for epoch in range(EPOCH):
    for step, (x, y) in enumerate(train_loader):
        b_x = Variable(x)
        b_y = Variable(y)            # gives batch data, normalize x when iterate train_loader

        output = cnn(b_x)               # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients


        if step % 50 ==0:
            test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            accuracy = sum(pred_y == test_y) / test_y.size(0)
            print('Epoch:',epoch,'| train loss: %.4f' % loss.item())

# print 10 predictions from test data
test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')






  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
资源包主要包含以下内容: ASP项目源码:每个资源包中都包含完整的ASP项目源码,这些源码采用了经典的ASP技术开发,结构清晰、注释详细,帮助用户轻松理解整个项目的逻辑和实现方式。通过这些源码,用户可以学习到ASP的基本语法、服务器端脚本编写方法、数据库操作、用户权限管理等关键技术。 数据库设计文件:为了方便用户更好地理解系统的后台逻辑,每个项目中都附带了完整的数据库设计文件。这些文件通常包括数据库结构图、数据表设计文档,以及示例数据SQL脚本。用户可以通过这些文件快速搭建项目所需的数据库环境,并了解各个数据表之间的关系和作用。 详细的开发文档:每个资源包都附有详细的开发文档,文档内容包括项目背景介绍、功能模块说明、系统流程图、用户界面设计以及关键代码解析等。这些文档为用户提供了深入的学习材料,使得即便是从零开始的开发者也能逐步掌握项目开发的全过程。 项目演示与使用指南:为帮助用户更好地理解和使用这些ASP项目,每个资源包中都包含项目的演示文件和使用指南。演示文件通常以视频或图文形式展示项目的主要功能和操作流程,使用指南则详细说明了如何配置开发环境、部署项目以及常见问题的解决方法。 毕业设计参考:对于正在准备毕业设计的学生来说,这些资源包是绝佳的参考材料。每个项目不仅功能完善、结构清晰,还符合常见的毕业设计要求和标准。通过这些项目,学生可以学习到如何从零开始构建一个完整的Web系统,并积累丰富的项目经验。
资源包主要包含以下内容: ASP项目源码:每个资源包中都包含完整的ASP项目源码,这些源码采用了经典的ASP技术开发,结构清晰、注释详细,帮助用户轻松理解整个项目的逻辑和实现方式。通过这些源码,用户可以学习到ASP的基本语法、服务器端脚本编写方法、数据库操作、用户权限管理等关键技术。 数据库设计文件:为了方便用户更好地理解系统的后台逻辑,每个项目中都附带了完整的数据库设计文件。这些文件通常包括数据库结构图、数据表设计文档,以及示例数据SQL脚本。用户可以通过这些文件快速搭建项目所需的数据库环境,并了解各个数据表之间的关系和作用。 详细的开发文档:每个资源包都附有详细的开发文档,文档内容包括项目背景介绍、功能模块说明、系统流程图、用户界面设计以及关键代码解析等。这些文档为用户提供了深入的学习材料,使得即便是从零开始的开发者也能逐步掌握项目开发的全过程。 项目演示与使用指南:为帮助用户更好地理解和使用这些ASP项目,每个资源包中都包含项目的演示文件和使用指南。演示文件通常以视频或图文形式展示项目的主要功能和操作流程,使用指南则详细说明了如何配置开发环境、部署项目以及常见问题的解决方法。 毕业设计参考:对于正在准备毕业设计的学生来说,这些资源包是绝佳的参考材料。每个项目不仅功能完善、结构清晰,还符合常见的毕业设计要求和标准。通过这些项目,学生可以学习到如何从零开始构建一个完整的Web系统,并积累丰富的项目经验。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值