pytorch学习笔记(六)

一、多分类问题相关知识

多分类问题实战:MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示。
MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所. 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50%来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址: http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

二、代码实现

import torch
import torch.nn as nn
import torchvision as tv

# 超参数
batch_size=200
learning_rate=0.01
epochs=10

# 训练集
train_loader = torch.utils.data.DataLoader(
    tv.datasets.MNIST('../data', train=True, download=True,          # train=True则得到的是训练集
                   transform=tv.transforms.Compose([                 # transform进行数据预处理
                       tv.transforms.ToTensor(),                     # 转成Tensor类型的数据
                       tv.transforms.Normalize((0.1307,), (0.3081,)) # 进行数据标准化(减去均值除以方差)
                   ])),
    batch_size=batch_size, shuffle=True)                          # 按batch_size分出一个batch维度在最前面,shuffle=True打乱顺序



# 测试集
test_loader = torch.utils.data.DataLoader(
    tv.datasets.MNIST('../data', train=False, transform=tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)

# 设定参数w和b
w1, b1 = torch.randn(200, 784, requires_grad=True),\
         torch.zeros(200, requires_grad=True)             # w1(out,in)
w2, b2 = torch.randn(200, 200, requires_grad=True),\
         torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\
         torch.zeros(10, requires_grad=True)

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)


def forward(x):
    x = x@w1.t() + b1
    x = torch.nn.function.relu(x)
    x = x@w2.t() + b2
    x = torch.nn.function.relu(x)
    x = x@w3.t() + b3
    x = torch.nn.function.relu(x)
    return x


#定义sgd优化器,指明优化参数、学习率
optimizer = torch.optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)

        logits = forward(data)               # 把数据放入神经网络得出pred的值
        loss = criteon(logits, target)       # 用loss函数计算pred和target的差
        optimizer.zero_grad()                # 清零梯度
        loss.backward()                      # 重新计算梯度
        optimizer.step()                     # 用新的梯度计算新的w,b,然后迭代

        if batch_idx % 100 == 0:             #每100个batch输出一次信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0                                         #correct记录正确分类的样本数
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        logits = forward(data)
        test_loss += criteon(logits, target).item()     #其实就是criteon(logits, target)的值,标量

        pred = logits.data.max(dim=1)[1]                # 得出pred的最大值,就是网络识别出的数字
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print(' {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值