PyTorch解决分类问题

获取数据

import torch 
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
# Load the MNIST train/test splits from ./data/MNIST. download=True fetches the
# files on a fresh machine and skips the download when they are already present,
# so the script no longer fails with "Dataset not found" on first run.
MNIST_ROOT = "./data/MNIST"
train_data = datasets.MNIST(root=MNIST_ROOT, train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root=MNIST_ROOT, train=False, transform=transforms.ToTensor(), download=True)
# Notebook-style echo of the dataset summary (a no-op when run as a script).
train_data
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data/MNIST
    Split: Train
    StandardTransform
Transform: ToTensor()
# Mini-batches of 64 samples. Only the training set is shuffled: evaluation does
# not benefit from a random order, and a fixed order keeps test runs reproducible.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=64, shuffle=False)

构建模型

# Compute device: CUDA when a GPU is visible to PyTorch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class CNN(nn.Module):
    """Single convolution-block classifier producing 10 class scores.

    Pipeline: 3x3 conv (1 -> 32 channels, padding 1) -> BatchNorm -> ReLU ->
    2x2 max-pool, then a linear layer over the flattened 32x14x14 feature map.
    The 14x14 size baked into ``fc`` means inputs must be 1-channel 28x28
    images (e.g. MNIST).
    """

    def __init__(self):
        super().__init__()
        # Attribute names `conv` / `fc` define the state_dict keys; keep them stable.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.fc = nn.Linear(14 * 14 * 32, 10)

    def forward(self, x):
        """Return raw (unnormalized) class scores of shape (batch, 10)."""
        features = self.conv(x)
        # Collapse every dimension except the batch dimension.
        flat = torch.flatten(features, start_dim=1)
        return self.fc(flat)


cnn = CNN()

定义损失函数和优化器

# Cross-entropy on the raw 10-way scores returned by CNN.forward (no softmax in
# the model), optimized with Adam at a fixed learning rate of 1e-3.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=cnn.parameters(), lr=1e-3)

模型训练

epochs = 10
import time

t1 = time.time()
# Send every batch to wherever the model's parameters live, so the loop keeps
# working if `cnn` is moved to a GPU (e.g. `cnn.to(device)`).
model_device = next(cnn.parameters()).device
for epoch in range(epochs):
    # Training mode: BatchNorm normalizes with per-batch statistics and
    # updates its running mean/variance.
    cnn.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(model_device), labels.to(model_device)
        outputs = cnn(images)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            # Report the 1-based batch number against the real batch count
            # (len(train_loader) includes the final partial batch).
            print(f"epoch:{epoch+1},iter:{i+1}/{len(train_loader)},loss_train:{loss.item()}")

    # Evaluation mode: BatchNorm uses the running statistics learned during
    # training and is not updated from test batches; no_grad() skips building
    # autograd graphs that are never backpropagated.
    cnn.eval()
    loss_test = 0.0
    acc = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(model_device), labels.to(model_device)
            outputs = cnn(images)
            # loss_func averages over the batch, so re-weight by batch size to
            # get an exact per-sample mean (the last batch may be smaller).
            # .item() accumulates a plain float rather than a tensor.
            loss_test += loss_func(outputs, labels).item() * labels.size(0)
            _, pred = outputs.max(1)
            acc += (pred == labels).sum().item()
    acc = acc / len(test_data)
    loss_test = loss_test / len(test_data)
    print(f"epoch:{epoch+1},acc:{acc},loss_test:{loss_test}")
print("training time is:", (time.time() - t1) / 60, "min")
epoch:1,iter:99/937,loss_train:0.16643951833248138
epoch:1,iter:199/937,loss_train:0.18687975406646729
epoch:1,iter:299/937,loss_train:0.08141929656267166
epoch:1,iter:399/937,loss_train:0.050456926226615906
epoch:1,iter:499/937,loss_train:0.047815751284360886
epoch:1,iter:599/937,loss_train:0.07874736934900284
epoch:1,iter:699/937,loss_train:0.06673698872327805
epoch:1,iter:799/937,loss_train:0.17430831491947174
epoch:1,iter:899/937,loss_train:0.07649435847997665
epoch:1,acc:0.9709,loss_test:0.0881509855389595
epoch:2,iter:99/937,loss_train:0.04289521649479866
epoch:2,iter:199/937,loss_train:0.04798528924584389
epoch:2,iter:299/937,loss_train:0.16610397398471832
epoch:2,iter:399/937,loss_train:0.011500637978315353
epoch:2,iter:499/937,loss_train:0.07874076068401337
epoch:2,iter:599/937,loss_train:0.03385515138506889
epoch:2,iter:699/937,loss_train:0.025212256237864494
epoch:2,iter:799/937,loss_train:0.043671611696481705
epoch:2,iter:899/937,loss_train:0.05083635821938515
epoch:2,acc:0.9785,loss_test:0.06763216853141785
epoch:3,iter:99/937,loss_train:0.0809536948800087
epoch:3,iter:199/937,loss_train:0.12687648832798004
epoch:3,iter:299/937,loss_train:0.023435810580849648
epoch:3,iter:399/937,loss_train:0.02110959030687809
epoch:3,iter:499/937,loss_train:0.030197784304618835
epoch:3,iter:599/937,loss_train:0.19537533819675446
epoch:3,iter:699/937,loss_train:0.10126397013664246
epoch:3,iter:799/937,loss_train:0.014641757123172283
epoch:3,iter:899/937,loss_train:0.07761335372924805
epoch:3,acc:0.9802,loss_test:0.07316284626722336
epoch:4,iter:99/937,loss_train:0.008765117265284061
epoch:4,iter:199/937,loss_train:0.14132143557071686
epoch:4,iter:299/937,loss_train:0.013033450581133366
epoch:4,iter:399/937,loss_train:0.025654081255197525
epoch:4,iter:499/937,loss_train:0.07331296056509018
epoch:4,iter:599/937,loss_train:0.05696120858192444
epoch:4,iter:699/937,loss_train:0.20584994554519653
epoch:4,iter:799/937,loss_train:0.017348136752843857
epoch:4,iter:899/937,loss_train:0.056173764169216156
epoch:4,acc:0.9821,loss_test:0.055867232382297516
epoch:5,iter:99/937,loss_train:0.03966124728322029
epoch:5,iter:199/937,loss_train:0.011977104470133781
epoch:5,iter:299/937,loss_train:0.013878636062145233
epoch:5,iter:399/937,loss_train:0.00939354207366705
epoch:5,iter:499/937,loss_train:0.014462064020335674
epoch:5,iter:599/937,loss_train:0.010698058642446995
epoch:5,iter:699/937,loss_train:0.045422524213790894
epoch:5,iter:799/937,loss_train:0.014188112691044807
epoch:5,iter:899/937,loss_train:0.018470803275704384
epoch:5,acc:0.9822,loss_test:0.05920825153589249
epoch:6,iter:99/937,loss_train:0.014892310835421085
epoch:6,iter:199/937,loss_train:0.014666389673948288
epoch:6,iter:299/937,loss_train:0.010755248367786407
epoch:6,iter:399/937,loss_train:0.03675413131713867
epoch:6,iter:499/937,loss_train:0.10592899471521378
epoch:6,iter:599/937,loss_train:0.01673516258597374
epoch:6,iter:699/937,loss_train:0.0003224133397452533
epoch:6,iter:799/937,loss_train:0.018152574077248573
epoch:6,iter:899/937,loss_train:0.08764231950044632
epoch:6,acc:0.9808,loss_test:0.06770947575569153
epoch:7,iter:99/937,loss_train:0.0032583947759121656
epoch:7,iter:199/937,loss_train:0.017065133899450302
epoch:7,iter:299/937,loss_train:0.010818807408213615
epoch:7,iter:399/937,loss_train:0.01183609664440155
epoch:7,iter:499/937,loss_train:0.00855921022593975
epoch:7,iter:599/937,loss_train:0.05112480744719505
epoch:7,iter:699/937,loss_train:0.006859672721475363
epoch:7,iter:799/937,loss_train:0.0020741026382893324
epoch:7,iter:899/937,loss_train:0.034429844468832016
epoch:7,acc:0.9797,loss_test:0.07257354259490967
epoch:8,iter:99/937,loss_train:0.002220737747848034
epoch:8,iter:199/937,loss_train:0.0025882022455334663
epoch:8,iter:299/937,loss_train:0.025042632594704628
epoch:8,iter:399/937,loss_train:0.05414767563343048
epoch:8,iter:499/937,loss_train:0.0008554743253625929
epoch:8,iter:599/937,loss_train:0.0037804560270160437
epoch:8,iter:699/937,loss_train:0.003353153821080923
epoch:8,iter:799/937,loss_train:0.09923534095287323
epoch:8,iter:899/937,loss_train:0.017603596672415733
epoch:8,acc:0.9809,loss_test:0.06367961317300797
epoch:9,iter:99/937,loss_train:0.0033164226915687323
epoch:9,iter:199/937,loss_train:0.003846828592941165
epoch:9,iter:299/937,loss_train:0.00036157420254312456
epoch:9,iter:399/937,loss_train:0.03678906708955765
epoch:9,iter:499/937,loss_train:0.1933242678642273
epoch:9,iter:599/937,loss_train:0.021390235051512718
epoch:9,iter:699/937,loss_train:0.02964860387146473
epoch:9,iter:799/937,loss_train:0.0059194485656917095
epoch:9,iter:899/937,loss_train:0.00673449644818902
epoch:9,acc:0.9822,loss_test:0.0644088163971901
epoch:10,iter:99/937,loss_train:0.0004312991804908961
epoch:10,iter:199/937,loss_train:0.0002770039136521518
epoch:10,iter:299/937,loss_train:0.018510431051254272
epoch:10,iter:399/937,loss_train:0.01306982897222042
epoch:10,iter:499/937,loss_train:0.0022419721353799105
epoch:10,iter:599/937,loss_train:0.0007191193872131407
epoch:10,iter:699/937,loss_train:0.01564517244696617
epoch:10,iter:799/937,loss_train:0.0059061129577457905
epoch:10,iter:899/937,loss_train:0.077512226998806
epoch:10,acc:0.982,loss_test:0.06344469636678696
training time is: 11.6351833264033 min

模型保存

import os

# exist_ok=True lets the script be re-run without raising FileExistsError
# once the checkpoint directory has been created.
os.makedirs("checkpoints/classify/", exist_ok=True)
# Pickles the whole module (architecture + weights). Loading it back requires
# the CNN class to be importable under the same name.
torch.save(cnn, "checkpoints/classify/simple_cnn_v1.pkl")

推理

# Reload the pickled model saved above and measure its test-set accuracy.
# The checkpoint is produced by this script, so unpickling it is trusted;
# never torch.load() a full-module pickle from an untrusted source.
try:
    # PyTorch >= 2.6 defaults to weights_only=True, which rejects full-module pickles.
    cnn = torch.load("checkpoints/classify/simple_cnn_v1.pkl", weights_only=False)
except TypeError:
    # Older PyTorch (< 1.13) does not accept the weights_only argument.
    cnn = torch.load("checkpoints/classify/simple_cnn_v1.pkl")
# Inference mode: BatchNorm uses its stored running statistics.
cnn.eval()
model_device = next(cnn.parameters()).device
# Count from zero: `acc` still holds the last epoch's accuracy from training,
# and adding onto it skewed the previously reported figure.
correct = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(model_device), labels.to(model_device)
        outputs = cnn(images)
        _, pred = outputs.max(1)
        correct += (pred == labels).sum().item()
acc = correct / len(test_data)
print(f"acc:{acc}")
acc:0.9816982
  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值