《Pytorch深度学习》第5章——MNIST手写体识别(CNN)

模型结构及各层Tensor的shape

模型结构

各层Tensor的shape

新版torch的代码实现及训练准确度

# -*- coding:utf-8 -*-
"""
author: zliu.elliot
@time: 2021-08-26 17:
@file: mnistWithCNN.py
"""
import torch.optim
from torchvision.transforms import transforms
from torchvision import datasets

from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as functional


def fit(epoch, model, dataloader, criterion, optimizer, phase="train"):
    if phase == "train":
        model.train()
    else:
        model.eval()
    run_loss = 0
    run_correct = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        if phase == "train":
            optimizer.zero_grad()
        output = model(data)
        predict = output.data.max(dim=1, keepdim=True)[1]

        loss = criterion(output, target)
        run_loss += functional.nll_loss(output, target, reduction='sum').item()

        predict = predict.reshape(target.shape)
        run_correct += predict.eq(target).sum().item()

        if phase == "train":
            loss.backward()
            optimizer.step()

    loss = run_loss / len(dataloader.dataset)
    accuracy = float(run_correct) / len(dataloader.dataset)
    print(f"[{epoch}]{phase} loss is {loss:.3f}, accuracy is {accuracy:.3f}%")
    return loss, accuracy


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root="./mnistDataset/", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./mnistDataset/", train=False, transform=transform, download=True)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

net = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5),
    nn.Dropout2d(),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(320, 50),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(50, 10),
    nn.LogSoftmax()
)

loss_fn = nn.NLLLoss()
optim = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
train_loss, train_acc = [], []
val_loss, val_acc = [], []


for epoch in range(20):
    epoch_loss, epoch_acc = fit(epoch, net, train_dataloader, loss_fn, optim)
    train_loss.append(epoch_loss), train_acc.append(epoch_acc)

    epoch_loss, epoch_acc = fit(epoch, net, test_dataloader, loss_fn, optim, phase="validation")
    val_loss.append(epoch_loss), val_acc.append(epoch_acc)

print(train_loss)
print(train_acc)

print(val_loss)
print(val_acc)

训练准确度

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值