简单的线性分类器训练MNIST(Pytorch基础练习)

完整文件:https://github.com/JintuZheng/Blog-/blob/master/Demo_LogicRegression_MNIST.py

包导入准备

import torchvision.datasets
import torchvision.transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torch.nn
import torch.optim

from debug import ptf_tensor

设置超参数

# Hyperparameters超参数
BATCH_SIZE=100
NUM_EPOCHS=5
DEVICE='cuda:0'

数据集下载

########################## 训练集的准备 ##############################################

train_dataset=torchvision.datasets.MNIST(root='D:/DataTmp/mnist',train=True,transform=torchvision.transforms.ToTensor(),download=True)
#root:下载数据存放到哪里,train:下载训练集还是测试集,transfrom:数据转化的形式

test_dataset=torchvision.datasets.MNIST(root='D:/DataTmp/mnist',train=False, transform=torchvision.transforms.ToTensor(),download=True)

【1】设置dataloader,分批读取数据,因为我们没办法一次训练过多数据

#由于数据集里面有上万条数据,我们需要分批从数据集读取数据
train_dataloader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE)
print('The len of train dataset={}'.format(len(train_dataset)))

test_dataloader=torch.utils.data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE)
print('The len of test dataset={}'.format(len(test_dataset)))

【2】查看数据格式

for images,labels in train_dataloader:
    print('The images size is {}',format(images.size())) 
    print('The labels size is {}'.format(labels.size())) 
    break #本循环就是执行一次

线性分类器准备

【1】构建一层的分类器

fc=torch.nn.Linear(28*28,10) #只使用一层线性分类器
fc.to(DEVICE)#如果用CPU去掉

【2】构建损失函数

criterion=torch.nn.CrossEntropyLoss()

【3】根据假设函数的参数构建优化器

optimizer=torch.optim.Adam(fc.parameters())

开始迭代训练

for epoch in range(NUM_EPOCHS):
    for idx, (images,labels) in enumerate(train_dataloader):
        x =images.reshape(-1,28*28)

        x=x.to(DEVICE)# 如果用CPU去掉
        labels=labels.to(DEVICE)# 如果用CPU去掉

        optimizer.zero_grad() #梯度清零
        preds=fc(x) #计算预测
        loss=criterion(preds,labels) #计算损失
        loss.backward() # 计算参数梯度
        optimizer.step() # 更新迭代梯度

        if idx % 100 ==0:
            print('epoch={}:idx={},loss={:g}'.format(epoch,idx,loss))

检验最后的正确率

correct=0
total=0

for idx,(images,labels) in enumerate(test_dataloader):
    x =images.reshape(-1,28*28) #对所有的图片进行reshape size(m,28*28)
    x=x.to(DEVICE)
    labels=labels.to(DEVICE)

    preds=fc(x)
    predicted=torch.argmax(preds,dim=1) #在dim=1中选取max值的索引
    if idx ==0:
        print('x size:{}'.format(x.size()))
        print('preds size:{}'.format(preds.size()))
        print('predicted size:{}'.format(predicted.size()))

    total+=labels.size(0)
    correct+=(predicted == labels).sum().item()
    #print('##########################\nidx:{}\npreds:{}\nactual:{}\n##########################\n'.format(idx,predicted,labels))

accuracy=correct/total
print('{:1%}'.format(accuracy))

参数数据的保存和复原

#保存
torch.save(fc.state_dict(), 'D:/DataTmp/mnist/tst.pth')
fc=torch.nn.Linear(28*28,10) #只使用一层线性分类器
#复原
fc.to(DEVICE)#如果用CPU去掉
fc.load_state_dict(torch.load('D:/DataTmp/mnist/tst.pth'))
  • 6
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值