Pytorch 手写数字识别2

前言:

       这里主要结合手写数字识别训练,验证过程,简单了解一下

Pytorch 主要应用的API函数,跟Numpy 不同,这里面有很多API

可以自动计算微分,梯度更新 等

参考:

CSDN

 


一  实现效果

    1.1  梯度更新过程

        

 

    1.2 测试集验证结果

          test  acc:   0.912      total_num:  10000

 

     1.3 预测图像显示

        下面为实际图像,上面为预测值

     

 


二  训练验证过程

  

   

# -*- coding: utf-8 -*-
"""
Created on Mon Nov 28 16:09:06 2022

@author: chengxf2
"""
import torch
import torch.optim as optim 
from torch import nn
import torch.nn.functional  as F
from down_data import load_pic
from util import one_hot
from util import plot_curve
from util import plot_image




class  Net(nn.Module):
    
    
    def  __init__(self):
        
        super(Net, self).__init__()
        
        #xW^T+b
        self.fc1 =  nn.Linear(28*28,256) 
        self.fc2 =  nn.Linear(256, 64)
        self.fc3 =  nn.Linear(64, 10)
        self.maxIter = 2
        
    
    '''
    前向传播
    '''
    def forward(self,x):
        
        #x: [m,28*28] 图片个数m
        #H1 =xW^T+b
        h1 = self.fc1(x)
        a1  =F.relu(h1)
        
        #h2 = xW^T+b
        h2 = self.fc2(a1)
        a2 = F.relu(h2)
        
        
        #h3 = h2w3+b3
        h3 = self.fc3(a2)
        #dim=0代表是列,dim=1代表是行 
        a3 = F.softmax(h3, dim=1)
        
        return a3
    
'''
训练模型
args:
    train_loader: 训练的数据集
'''  
def train(train_loader):
        
        #w1, b1, w2 b2, w3 b3 3层神经网络结构
        net = Net()
        optimizer = optim.SGD(net.parameters(),lr=0.01)
        maxIter = 20
        
        trainLoss = []
        for epoch in range(maxIter): #对数据集进行递归
            for batch_idx, (x,y) in  enumerate(train_loader):

                optimizer.zero_grad()
                #print(x.shape, y.shape) #[m,1,28,28] 实际为4维
                x =x.view(x.size(0),28*28) #图片维度切换[1,28*28]
               
                out = net(x) 
    
   
                y_onehot = one_hot(y)#
                
                loss = F.cross_entropy(out,y_onehot)
                loss.backward()  #计算梯度
                optimizer.step() #梯度更新 w= w-lr*grad
                
                trainLoss.append(loss.item()) #保存梯度
                if batch_idx%1000 ==0:
                    print("epoch:%d  batch_idx: %d   loss: %7.4f"%(epoch, batch_idx, loss.item()))
        #[w1,b1,w2,b2,w3,b3]
        plot_curve(trainLoss)
        return net
        
 
# 使用测试机来验证
def  verify(test_loader,net):
     total_correct = 0
    
     
     for x,y in test_loader:
         N = x.size(0) #样本个数
         x = x.view(N,28*28)
         out = net(x) #[50,10]
         #print("\n  out: ",out,out.shape) #[50, 10]
         pred = out.argmax(dim = 1)  #torch.Size([50]
         correct = pred.eq(y).sum().float().item() 
         total_correct += correct
     
     total_num = len(test_loader.dataset)
     acc = total_correct/total_num

     print("test  acc: %7.3f"%acc,"\t total_num: ",total_num)      



# 图形化显示
def verify_show(test_loader,net):

     x,y = next(iter(test_loader)) #单独取一个batch
     
     N = x.size(0)#N=1 ,1, 1, 28, 28
     print("\n N: %d"%N,x.shape, x.type) #[50, 10]
     X = x.view(N,28*28) #N 取决于batch_size
     out = net(X)
     #out [N,10] ==> pred: [N]
     pred = out.argmax(dim=1)   
     plot_image(x,pred,'predict: ')               
            

if  __name__ == "__main__":
    
    
    train_loader,test_loader = load_pic()
    
    net = train(train_loader)
    verify(test_loader, net)
    #verify_show(test_loader, net)
    
    
            
            
    
        
        

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值