【深度学习与神经网络】MNIST手写数字识别1

本文介绍了如何使用PyTorch构建一个简单的全连接神经网络,应用于MNIST手写数字识别任务。网络结构包含一层全连接层,使用MSE损失函数和SGD优化器进行训练,但最终测试准确率不高。
摘要由CSDN通过智能技术生成

简单的全连接层

导入相应库

import torch
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

读入数据并转为tensor向量

# 训练集
# 转为tensor数据
train_dataset = datasets.MNIST(root='./',train=True, transform = transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./',train=False, transform = transforms.ToTensor(), download=True)

装载数据集

# 批次大小
batch_size = 64

# 装载训练集
train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batch_size, shuffle = True)

定义网络结构
一层全连接网络,最后使用softmax转概率值输出

# 定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 10)
        self.softmax = nn.Softmax(dim =1)
        
    def forward(self, x):
        # [64,1,28,28] ——> [64, 784]
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.softmax(x)
        return x   

定义模型
使用均方误差损失函数,梯度下降优化

# 定义模型
model = Net()
mes_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),0.5)

训练并测试网络:
训练时注意最后输出(64,10)
标签是(64) ,需要将其转为one-hot编码(64,10)

def train():
    for i,data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs, labels = data
        # 获得模型结果 (64,10)
        out = model(inputs)
        # to one-hot 把数据标签变为独热编码
        labels = labels.reshape(-1,1)
        one_hot = torch.zeros(inputs.shape[0],10).scatter(1, labels, 1)
        # 计算loss
        loss = mes_loss(out, one_hot)
        # 梯度清0
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()
        
def test():
    correct = 0
    for i,data in enumerate(test_loader):
        # 获得一个批次的数据和标签
        inputs, labels = data
        # 获得模型结果 (64,10)
        out = model(inputs)
        # 获取最大值和最大值所在位置
        _,predicted = torch.max(out,1)
        # 预测正确数量
        correct += (predicted == labels).sum()
        
        
    print("test ac:{0}".format(correct.item()/len(test_dataset)))
        
  

调用模型 训练10次

# 使用mse损失函数 
for epoch in range(10):
    print("epoch:",epoch)
    train()
    test()

训练结果:
在这里插入图片描述
准确率不够

  • 15
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值