PyTorch: Learning Notes 7, Implementing Softmax Regression in PyTorch

Implementing softmax regression in PyTorch
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
from collections import OrderedDict

# Fashion-MNIST is a 10-class clothing classification dataset
mnist_train = torchvision.datasets.FashionMNIST(root='Datasets/FashionMNIST', train=True, download=False, transform=transforms.ToTensor())  # download=True downloads the dataset from the internet
mnist_test = torchvision.datasets.FashionMNIST(root='Datasets/FashionMNIST', train=False, download=False, transform=transforms.ToTensor())
print(type(mnist_train))  # <class 'torchvision.datasets.mnist.FashionMNIST'>
print(len(mnist_train), len(mnist_test))  # 60000 10000

feature, label = mnist_train[0]  # the 0th image and its label
print(feature.shape, label)  # (Channel x Height x Width) torch.Size([1, 28, 28]) 9
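# transforms.ToTensor() converts the PIL image to a float tensor and scales
# pixel values from [0, 255] to [0.0, 1.0]; a quick sanity check (this print
# is an illustrative addition, not part of the original script):
print(feature.min().item(), feature.max().item())  # expected: 0.0 1.0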

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
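# Example: get_fashion_mnist_labels([9, 0]) -> ['ankle boot', 't-shirt']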

def show_fashion_mnist(images, labels):
    # _ here is a variable we ignore (do not use)
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
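# Note: show_fashion_mnist renders single-channel images with matplotlib's
# default colormap; passing cmap='gray' to imshow would give grayscale output
# (an optional tweak, not in the original).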

# Show the images and text labels of the first 10 samples in the training set
def show_fashion_mnist_first_ten_images():
    X, y = [], []
    for i in range(10):
        X.append(mnist_train[i][0])  # mnist_train has 60,000 samples; each sample is an image mnist_train[i][0] plus its label (0-9) mnist_train[i][1]
        y.append(mnist_train[i][1])
    show_fashion_mnist(X, get_fashion_mnist_labels(y))

# Read mini-batches of size batch_size
def get_train_test_batchsize_data(batch_size=256):
    if sys.platform.startswith('win'):
        num_workers = 0  # 0 means no extra worker processes are used to speed up data loading
    else:
        num_workers = 4
    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_iter, test_iter

# Time one full pass over the training data to gauge DataLoader throughput
def time_one_epoch(train_iter):
    start = time.time()
    for X, y in train_iter:
        continue
    print('%.2f sec' % (time.time() - start))

# Define the network: method 1
# class LinearNet(torch.nn.Module):
#     def __init__(self, num_inputs, num_outputs):
#         super(LinearNet, self).__init__()
#         self.linear = torch.nn.Linear(num_inputs, num_outputs)
#     def forward(self, x): # x shape: (batch, 1, 28, 28)
#         y = self.linear(x.view(x.shape[0], -1))  # each batch x has shape (batch_size, 1, 28, 28), so view() first reshapes it to (batch_size, 784) before it enters the fully connected layer
#         return y
# # num_inputs = 784  # 28 * 28
# # num_outputs = 10  # 10 classes
# # net = LinearNet(num_inputs, num_outputs)

# Reshape x (flatten everything except the batch dimension)
class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x): # x shape: (batch, *, *, ...)
        return x.view(x.shape[0], -1)
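# Note: recent PyTorch releases (1.2+) ship torch.nn.Flatten, a drop-in
# replacement for this custom layer (an alternative, assuming a sufficiently
# new PyTorch install):
# flatten = torch.nn.Flatten()  # by default flattens all dims except the batch dim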

# Evaluation
def evaluate_accuracy(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
        n += y.shape[0]
    return acc_sum / n
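# A variant that disables autograd during evaluation (a minimal sketch, not
# from the original post): torch.no_grad() skips building the computation
# graph, which saves memory when we only need forward passes.
def evaluate_accuracy_no_grad(data_iter, net):
    acc_sum, n = 0.0, 0
    with torch.no_grad():
        for X, y in data_iter:
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n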

# Training: mini-batch stochastic gradient descent
def sgd(params, lr, batch_size):
    for param in params:
        param.data -= lr * param.grad / batch_size  # note: update param.data so the update itself is not tracked by autograd
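# Dividing by batch_size assumes the loss was *summed* over the batch before
# backward(), so param.grad holds a sum of per-example gradients and the
# division turns the step into an average-gradient update.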

# Training loop
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()

            # Zero the gradients
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()

            l.backward()
            if optimizer is None:
                sgd(params, lr, batch_size)
            else:
                optimizer.step()  # used in the "concise implementation of softmax regression" section


            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
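# Usage with the manual-SGD path (a hypothetical sketch: W, b, and a
# per-example cross_entropy loss would come from a from-scratch version,
# none of which are defined in this script):
# train_ch3(net, train_iter, test_iter, cross_entropy, 5, 256, [W, b], 0.1)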

if __name__ == '__main__':
    # 1. Show the first 10 images of Fashion-MNIST
    # show_fashion_mnist_first_ten_images()

    # 2. Get batched training and test data
    batch_size = 256
    train_iter, test_iter = get_train_test_batchsize_data(batch_size=batch_size)

    # 3. Define the network: method 2
    num_inputs = 784  # 28 * 28
    num_outputs = 10  # 10 classes
    net = torch.nn.Sequential(
        # FlattenLayer(),
        # nn.Linear(num_inputs, num_outputs)
        OrderedDict([
            ('flatten', FlattenLayer()),
            ('linear', torch.nn.Linear(num_inputs, num_outputs))
        ])
    )
    # 4. Initialize the weights
    torch.nn.init.normal_(net.linear.weight, mean=0, std=0.01)
    torch.nn.init.constant_(net.linear.bias, val=0)
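    # The trailing underscore in normal_ / constant_ follows PyTorch's
    # convention for in-place operations: these initializers overwrite
    # net.linear.weight and net.linear.bias directly.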

    # 5. Define the loss function: PyTorch provides a single function that combines the softmax computation and the cross-entropy loss
    loss = torch.nn.CrossEntropyLoss()
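    # CrossEntropyLoss applies log-softmax to raw logits and then the negative
    # log-likelihood loss in one numerically stable call; a quick equivalence
    # check on dummy logits (an illustrative addition, not part of the
    # original script):
    _logits = torch.randn(3, num_outputs)
    _targets = torch.tensor([0, 5, 9])
    _ce = loss(_logits, _targets)
    _nll = torch.nn.NLLLoss()(torch.log_softmax(_logits, dim=1), _targets)
    print(torch.allclose(_ce, _nll))  # True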

    # 6. Define the optimization algorithm
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

    # 7. Train the model
    num_epochs = 5
    train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)
Output:

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
torch.Size([1, 28, 28]) 9
epoch 1, loss 0.0031, train acc 0.747, test acc 0.766
epoch 2, loss 0.0022, train acc 0.812, test acc 0.807
epoch 3, loss 0.0021, train acc 0.825, test acc 0.806
epoch 4, loss 0.0020, train acc 0.833, test acc 0.787
epoch 5, loss 0.0019, train acc 0.836, test acc 0.825
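
Note: the printed loss looks tiny because loss here is torch.nn.CrossEntropyLoss with its default mean reduction, so train_l_sum accumulates one per-batch average per iteration but is divided by the total sample count n; the displayed value is therefore roughly the true average cross-entropy divided by batch_size (e.g. 0.0031 × 256 ≈ 0.79 at epoch 1).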

These are study notes consolidating what I learned from the reference below, not an original implementation.
https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.7_softmax-regression-pytorch
