用torch实现手写识别算法


# torch 经典操作,引入nn和F,里面有神经网络的各种操作
from torch import nn
from torch.nn import functional as F

# 各种优化算法的库
from torch import optim

# 画图
import torchvision
from matplotlib import pyplot as plt

# 涉及一些其他操作,如引入手写设别数据集
import torch

# 绘制曲线的方法
def plot_curve(data):
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

# 绘制图片的方法,用于显示手写识别
def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

# 抓转换为独热码
def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

batch_size = 512  # 一次训练所选取的样本数,这里是我们自己设定的

# 步骤一:加载数据集(训练集和测试集)
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(  # 均匀分布在0附近
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,))
                               ])),
    batch_size=batch_size, shuffle=False)

# iter和next:
# list、tuple等都是可迭代对象,我们可以通过iter()函数获取这些可迭代对象的迭代器。然后,我们可以对获取到的迭代器不断使⽤next()函数来获取下⼀条数据。
x, y = next(iter(train_loader))
print(x.shape, y.shape)

plot_image(x, y, "imageSimple")

# 步骤二:定义神经网络
# 1.需要继承nn.Module类,并实现forward方法
# 2.继承nn.Module类之后,在构造函数中要调用Module的构造函数, super(Linear, self).init()
# 3.一般把网络中具有可学习参数的层放在构造函数__init__()中。
class Net(nn.Module):

    # 定义神经网络
    def __init__(self):
        super(Net, self).__init__()

        # xw+b
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    # 前向传播
    def forward(self, x):
        # x: [b,1,28,28]  b张图片,这里b为512
        # h1 = relu(xw+b)
        x = F.relu(self.fc1(x))
        # h1 = relu(xw+b)
        x = F.relu(self.fc2(x))
        # h1 = xw+b
        x = self.fc3(x)

        return x


def trainAndTest():
    net = Net()
    optimizer = optim.SGD(net.parameters(),lr = 0.01,momentum = 0.9)        # 优化器,定义学习率,momentum是冲量
    # momentum:
    # 当本次梯度下降- dx * lr的方向与上次更新量v的方向相同时,上次的更新量能够对本次的搜索起到一个正向加速的作用。
    # 当本次梯度下降- dx * lr的方向与上次更新量v的方向相反时,上次的更新量能够对本次的搜索起到一个减速的作用。

    # 步骤三:迭代训练
    train_loss = []     # 记录loss的变化,用来画图
    # 迭代三次
    for epoch in range(3):
        # 对一个batch迭代一次
        for batch_idx, (x, y) in enumerate(train_loader):
            #       enumerate:
            #       seasons = ['Spring', 'Summer', 'Fall', 'Winter']
            #       list(enumerate(seasons))
            #       [(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]

            # x:[512, 1, 28, 28]    y:[512]
            # => x:[512,10]
            x = x.view(x.size(0),28*28)     # view()的作用相当于numpy中的reshape
            # => [b,10]
            out = net(x)
            # class Module(nn.Module): 中的nn.Module存在魔法方法:__call__ 他调用了 forward函数。使得 net(data) 等价于 net.forward(data) !!!
            y_onehot = one_hot(y)
            loss = F.mse_loss(out,y_onehot)     # 均方差

            optimizer.zero_grad()      # 清空过往梯度
            loss.backward()             # 反向传播,计算当前梯度
            # w' = w - lr*grad
            optimizer.step()        # 根据梯度更新网络参数

            train_loss.append(loss.item())

            if batch_idx % 10 == 0:
                if batch_idx % 30 == 0:
                    print("迭代次数", '\t', "批数", '\t', "loss值")
                print(epoch,'\t',batch_idx,'\t',loss.item())
#                 item()取出张量具体位置的元素元素值

#   训练完,画出loss的变化值
    plot_curve(train_loss)

    # 步骤四:模型在测试集上训练
    total_correct = 0

    for x,y in test_loader:
        # 这里一次循环是一批操作,即一次处理512
        x = x.view(x.size(0),28*28)
        out = net(x)
        # out:[b,10] => pred:[b]
        # out:[512,10] ----dim = 1----> pred: [512]
        pred = out.argmax(dim = 1)
        correct = pred.eq(y).sum().float().item()       # pred中和y相等的,的综合,并转化为小数形式,并取出张量的元素值
        total_correct += correct

    total_num = len(test_loader.dataset)
    acc = total_correct / total_num
    print("test acc:",acc)

    # 用具体图片展示,这里展示第一批
    x,y = next(iter(test_loader))
    print(x.shape,y.shape)
    out = net(x.view(x.size(0),28*28))
    pred = out.argmax(dim=1)
    plot_image(x,pred,'test')

trainAndTest()

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值