《动手学深度学习》7.3网络中的网络(NiN)

《动手学深度学习》7.3网络中的网络(NiN)

  • 导入功能包
import torch
from torch import nn
import MyFunction as MF
  • 参数设置
# Training hyper-parameters: learning rate, number of epochs, mini-batch size.
lr = 0.1
num_epochs = 10
batch_size = 128
  • 读取数据
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Fashion-MNIST iterators; images are resized to 224x224 by the loader.
train_iter, test_iter = MF.load_data_fashion_mnist(batch_size=batch_size, resize=224)
  • NiN模型
# Define a NiN block
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    """Build one NiN block as an ``nn.Sequential``.

    The block is a convolution with the given kernel size, stride and
    padding, followed by two 1x1 convolutions; every convolution is
    followed by a ReLU.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),
    ]
    # Two 1x1 convolutions that keep the channel count unchanged.
    for _ in range(2):
        layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=1))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

# Define the NiN network model
net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(p=0.5),
    # There are 10 label classes, so the last block outputs 10 channels
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    # Flatten the 4-D output into a 2-D output of shape (batch size, 10)
    nn.Flatten(),
)
  • 训练
# Train the network; train_ch7 saves a checkpoint after every epoch.
MF.train_ch7(
    net,
    train_iter,
    test_iter,
    num_epochs=num_epochs,
    lr=lr,
    device=device,
)
  • 训练结果
    在这里插入图片描述

预测

import torch
from torch import nn
import MyFunction as MF
import matplotlib.pyplot as plt

# Settings for the prediction script.
batch_size = 128
lr = 0.9
num_epochs = 10
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fashion-MNIST iterators; images are resized to 224x224 by the loader.
train_iter, test_iter = MF.load_data_fashion_mnist(batch_size=batch_size, resize=224)

# Define a NiN block
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    """Return a NiN block: one spatial convolution plus two 1x1 convolutions.

    Each of the three convolutions is followed by a ReLU activation, and the
    layers are wrapped in an ``nn.Sequential`` in that order.
    """
    spatial_conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=strides,
        padding=padding,
    )
    pointwise_first = nn.Conv2d(out_channels, out_channels, kernel_size=1)
    pointwise_second = nn.Conv2d(out_channels, out_channels, kernel_size=1)
    return nn.Sequential(
        spatial_conv, nn.ReLU(),
        pointwise_first, nn.ReLU(),
        pointwise_second, nn.ReLU(),
    )

# Define the NiN network model (must match the architecture used for training
# so that the saved state dict can be loaded).
net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(p=0.5),
    # There are 10 label classes, so the last block outputs 10 channels
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    nn.AdaptiveAvgPool2d(output_size=(1, 1)),
    # Flatten the 4-D output into a 2-D output of shape (batch size, 10)
    nn.Flatten(),
)

# Load the model parameters.
# map_location remaps the saved tensors onto the device available here, so a
# checkpoint written from a CUDA model can still be loaded on a CPU-only machine.
net.load_state_dict(torch.load("./data/NiN-9.pth", map_location=device))
# Switch to evaluation mode so that Dropout is disabled during prediction.
net.eval()
# Predict
MF.predict_ch7(net, test_iter)
plt.show()
  • 预测结果

在这里插入图片描述

MyFunction 模块中定义的相关函数(训练函数、预测函数与绘图函数)

  • 导入功能包
import torch
import torchvision
from torchvision import transforms
from torch.utils import data
import sys
import MyFunction as MF
from torch import nn
import numpy as np
import time
from tqdm import tqdm
  • 训练函数
# Training function (with a tqdm progress bar)
def train_ch7(net, train_iter, test_iter, num_epochs, lr, device):
    """Train ``net`` with SGD and cross-entropy loss (defined in chapter 6).

    Linear and Conv2d weights are Xavier-initialised, the model is moved to
    ``device``, and after every epoch the state dict is saved to
    ``./data/NiN-<epoch>.pth`` (``<epoch>`` is zero-based), the test accuracy
    is evaluated and a one-line summary is printed.

    Args:
        net: the model to train (an ``nn.Module``).
        train_iter: iterable yielding ``(X, y)`` training mini-batches.
        test_iter: iterable passed to ``MF.evaluate_accuracy_gpu_ch6``.
        num_epochs: number of passes over ``train_iter``.
        lr: learning rate for ``torch.optim.SGD``.
        device: ``torch.device`` (or device string) to train on.
    """
    import os

    def init_weights(m):
        # isinstance also covers subclasses of Linear / Conv2d.
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(m.weight)

    device = torch.device(device)
    net.apply(init_weights)
    # Only query the CUDA device name when actually training on CUDA;
    # torch.cuda.get_device_name() raises on CPU-only machines.
    if device.type == "cuda":
        print(f'training on {device}:{torch.cuda.get_device_name(device)}')
    else:
        print(f'training on {device}')
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    # Make sure the checkpoint directory exists before the first torch.save.
    os.makedirs("./data", exist_ok=True)

    timer = MF.Timer()
    n = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        train_l, train_acc = 0.0, 0.0
        net.train()
        with tqdm(train_iter) as t:
            # Text shown on the left of the progress bar
            t.set_description(f"epoch:{epoch}")
            for X, y in t:
                timer.start()
                optimizer.zero_grad()
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                l = loss(y_hat, y)
                l.backward()
                optimizer.step()
                # The loss is a batch mean, so weight it by the batch size.
                train_l_sum += l.item() * X.shape[0]
                # NOTE(review): presumably MF.accuracy returns the number of
                # correct predictions in the batch - confirm against MyFunction.
                train_acc_sum += MF.accuracy(y_hat, y)
                n += y.shape[0]
                train_l = train_l_sum / n
                train_acc = train_acc_sum / n
                # Stop the timer exactly once per batch and reuse the value.
                # NOTE(review): assumes a d2l-style Timer whose stop() records
                # an interval on every call - repeated calls would inflate
                # timer.sum().
                batch_time = timer.stop()
                # Text shown on the right of the progress bar
                t.set_postfix(loss="%.3f" % train_l, train_acc="%.3f" % train_acc, time="%.3f sec" % batch_time)

        torch.save(net.state_dict(), "./data/NiN-%d.pth" % (epoch))
        test_acc = MF.evaluate_accuracy_gpu_ch6(net, test_iter)
        # Report every epoch, since the test accuracy is evaluated every epoch.
        print(f'epoch:{epoch+1},loss {train_l:.3f}, train_acc {train_acc:.3f}, test_acc {test_acc:.3f}, {timer.sum():.1f} sec')

    total_time = timer.sum()
    # Guard against division by zero when nothing was trained.
    if total_time > 0:
        print(f'{n* num_epochs / total_time:.1f} examples/sec '
              f'on {str(device)}')
  • 预测函数
def predict_ch7(net, test_iter, n=6):
    """Predict labels for the first test batch and plot the first ``n`` images.

    Each plotted image is titled with its true label on the first line and
    the predicted label on the second line. Only the first batch of
    ``test_iter`` is used.

    Args:
        net: the trained model; it is moved to the selected device and
            temporarily switched to evaluation mode.
        test_iter: iterable yielding ``(X, y)`` mini-batches.
        n: maximum number of images to show.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    was_training = net.training
    net.to(device)
    # Evaluation mode disables Dropout so the predictions are deterministic.
    net.eval()
    try:
        for X, y in test_iter:
            # Never ask for more images than the batch contains.
            count = min(n, X.shape[0])
            h, w = X.shape[-2:]
            images, targets = X[0:count], y[0:count]
            # Tensor.to() returns a new tensor, so its result must be used.
            with torch.no_grad():
                predicted = net(images.to(device)).argmax(dim=1).cpu()
            # True labels
            trues = MF.get_fashion_mnist_labels(targets)
            # Predicted labels
            preds = MF.get_fashion_mnist_labels(predicted)
            titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
            # NOTE(review): the reshape assumes single-channel images of shape
            # (batch, 1, h, w) - confirm for datasets other than Fashion-MNIST.
            MF.show_images(images.reshape((count, h, w)), 1, count, titles=titles)

            break
    finally:
        # Restore the mode the caller's model was in.
        net.train(was_training)
  • 绘图函数
# Get the text labels
def get_fashion_mnist_labels(labels):
    """Return the Fashion-MNIST text label for each numeric class index."""
    class_names = (
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    )
    names = []
    for index in labels:
        # int() accepts plain numbers as well as scalar tensors.
        names.append(class_names[int(index)])
    return names

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images on a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs: iterable of images; each is a tensor or anything ``imshow``
            accepts (e.g. a PIL image).
        num_rows: number of subplot rows.
        num_cols: number of subplot columns.
        titles: optional sequence of titles, one per image.
        scale: figure size per subplot.

    Returns:
        The flattened array of matplotlib axes.
    """
    # Imported locally because this module does not import matplotlib (or
    # d2l) at the top of the file.
    import matplotlib.pyplot as plt

    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so that
    # flatten() below always works.
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Image tensor: detach and move to the CPU before converting, so
            # CUDA tensors and tensors that track gradients are accepted too.
            ax.imshow(img.detach().cpu().numpy())
        else:
            # PIL image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        # Tolerate fewer titles than images instead of raising IndexError.
        if titles and i < len(titles):
            ax.set_title(titles[i])
    return axes
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值