Pytorch极简入门教程(十)—— 手写数字识别

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

"""
torchvision内置了常用数据集和常见的模型
"""
import torchvision
# datasets 放数据集  transforms 转换数据集
from torchvision import datasets, transforms
# transforms.ToTensor (1)转化为一个tensor (2)转换到0-1之间 (3)会将channel放在第一维度上

transformation = transforms.Compose([transforms.ToTensor()
                                     ])
train_ds = datasets.MNIST("./data", train=True, transform=transformation, download=True)

test_ds = datasets.MNIST("./data", train=False, transform=transformation, download=True)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=256)

imgs, labels = next(iter(train_dl))
print("imgs.shape:\t", imgs.shape)

"""
在pytorch里面图片的表示形式:[batch, channel, hight, width]
"""

img = imgs[0]
print("img.shape:\t", img.shape)

img = img.numpy()
print("img.shape:\t", img.shape)
# squeeze() 从数组的形状中删除单维度条目,即把shape中未1的维度去掉
img = np.squeeze(img)
print("np.squeeze(img):\t", img)
print("img.shape:\t", img.shape)

plt.imshow(img)
plt.show()
# 查看第一张图片的标签
print("labels[0]:\t", labels[0])

def imshow(img):
    npimg = img.numpy()
    npimg = np.squeeze(npimg)
    plt.imshow(npimg)
plt.figure(figsize=(10, 1))
for i, img in enumerate(imgs[:10]):
    plt.subplot(1, 10, i+1)
    imshow(img)
plt.show()

# 创建模型
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.liner_1 = nn.Linear(784, 120)
        self.liner_2 = nn.Linear(120, 84)
        self.liner_3 = nn.Linear(84, 10)
    def forward(self, input):
        x = input.view(-1, 784)
        x = F.relu(self.liner_1(x))
        x = F.relu(self.liner_2(x))
        x = self.liner_3(x)
        return x

loss_fn = torch.nn.CrossEntropyLoss()
model = Model()


# 编码一个fit函数,对输入模型、输入数据(train_dl, test_dl),对数据输入在模型上训练,并且返回loss和acc变化
def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    for x, y in trainloader:
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / total

    test_correct = 0
    test_total = 0
    test_running_loss = 0

    with torch.no_grad():
        for x, y in testloader:
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)

            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()

    epoch_test_loss = running_loss / len(trainloader.dataset)
    epoch_test_acc = correct / total

    print("epoch:\t{} loss:\t{} accuracy:\t{} test_loss:\t{} test_accuracy:\t{}".format(epoch, round(epoch_loss, 3),
                                                                                        round(epoch_acc, 3),
                                                                                        round(epoch_test_loss, 3),
                                                                                        round(epoch_test_acc, 3)))
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

optim = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 20

train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch, model, train_dl, test_dl)

    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_loss.append(epoch_test_acc)

imgs.shape:	 torch.Size([64, 1, 28, 28])
img.shape:	 torch.Size([1, 28, 28])
img.shape:	 (1, 28, 28)
np.squeeze(img):	 [[0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.09803922 0.69411767 0.61960787 0.8784314  0.39607844
  0.00784314 0.1254902  0.7490196  0.04705882 0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.14509805
  0.6901961  0.9137255  0.9882353  0.99215686 0.9882353  0.9882353
  0.15294118 0.33333334 0.9882353  0.27450982 0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.2901961  0.9019608
  0.9882353  0.9882353  0.9882353  0.6039216  0.6        0.6
  0.5254902  0.92941177 0.9882353  0.08235294 0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.34117648 0.68235296 0.9882353
  0.972549   0.5803922  0.09411765 0.         0.         0.07450981
  0.5568628  0.8745098  0.32156864 0.00392157 0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.05882353 0.87058824 0.9882353  0.77254903
  0.28235295 0.         0.         0.03137255 0.33333334 0.92941177
  0.9882353  0.6784314  0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.39215687 0.9882353  0.7176471  0.01960784
  0.         0.         0.15294118 0.80784315 0.9882353  0.9882353
  0.92156863 0.07450981 0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.9490196  0.9882353  0.6431373  0.02352941
  0.12156863 0.60784316 0.9411765  0.6117647  0.6745098  0.9137255
  0.88235295 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.9490196  0.9882353  0.85490197 0.7882353
  0.9882353  0.9137255  0.6509804  0.6039216  0.8745098  0.9882353
  0.4745098  0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.30980393 0.9882353  0.9882353  0.9882353
  0.8039216  0.25882354 0.03529412 0.13333334 0.9882353  0.9882353
  0.15294118 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.01176471 0.4117647  0.5568628  0.3529412
  0.05098039 0.         0.19215687 0.69411767 0.9882353  0.6509804
  0.00784314 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.43529412 1.         0.99215686 0.17254902
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.08235294 0.7647059  0.99215686 0.80784315 0.05098039
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.65882355 0.9882353  0.99215686 0.56078434 0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.12941177 0.8666667  0.9882353  0.83137256 0.01176471 0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.01176471
  0.6862745  0.9882353  0.9882353  0.43529412 0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.04705882
  0.9882353  0.9882353  0.8392157  0.09411765 0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.4627451
  0.9882353  0.9882353  0.5568628  0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.0627451  0.92156863
  0.9882353  0.9372549  0.27450982 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.34901962 0.9882353
  0.9882353  0.36862746 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.34901962 0.9882353
  0.6509804  0.03529412 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]]
img.shape:	 (28, 28)
labels[0]:	 tensor(9)
epoch:	0 loss:	0.005 accuracy:	0.906 test_loss:	0.005 test_accuracy:	0.906
epoch:	1 loss:	0.002 accuracy:	0.958 test_loss:	0.002 test_accuracy:	0.958
epoch:	2 loss:	0.001 accuracy:	0.971 test_loss:	0.001 test_accuracy:	0.971
epoch:	3 loss:	0.001 accuracy:	0.978 test_loss:	0.001 test_accuracy:	0.978
epoch:	4 loss:	0.001 accuracy:	0.983 test_loss:	0.001 test_accuracy:	0.983
epoch:	5 loss:	0.001 accuracy:	0.987 test_loss:	0.001 test_accuracy:	0.987
epoch:	6 loss:	0.001 accuracy:	0.988 test_loss:	0.001 test_accuracy:	0.988
epoch:	7 loss:	0.0 accuracy:	0.991 test_loss:	0.0 test_accuracy:	0.991
epoch:	8 loss:	0.0 accuracy:	0.992 test_loss:	0.0 test_accuracy:	0.992
epoch:	9 loss:	0.0 accuracy:	0.993 test_loss:	0.0 test_accuracy:	0.993
epoch:	10 loss:	0.0 accuracy:	0.994 test_loss:	0.0 test_accuracy:	0.994
epoch:	11 loss:	0.0 accuracy:	0.995 test_loss:	0.0 test_accuracy:	0.995
epoch:	12 loss:	0.0 accuracy:	0.995 test_loss:	0.0 test_accuracy:	0.995
epoch:	13 loss:	0.0 accuracy:	0.995 test_loss:	0.0 test_accuracy:	0.995
epoch:	14 loss:	0.0 accuracy:	0.996 test_loss:	0.0 test_accuracy:	0.996
epoch:	15 loss:	0.0 accuracy:	0.996 test_loss:	0.0 test_accuracy:	0.996
epoch:	16 loss:	0.0 accuracy:	0.996 test_loss:	0.0 test_accuracy:	0.996
epoch:	17 loss:	0.0 accuracy:	0.997 test_loss:	0.0 test_accuracy:	0.997
epoch:	18 loss:	0.0 accuracy:	0.997 test_loss:	0.0 test_accuracy:	0.997
epoch:	19 loss:	0.0 accuracy:	0.997 test_loss:	0.0 test_accuracy:	0.997
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值