[PyTorch in Practice] Training NiN on the FashionMNIST Dataset

Contents

Model Architecture

Existing Challenges

NiN's Strategy

The NiN Block

Code

Training Results


Model Architecture

Existing Challenges

  • The fully connected layers at the end of the network carry far too many parameters (see the quick count below)
  • Fully connected layers cannot be added early in the network to increase nonlinearity, because doing so would destroy the spatial structure and would also be very expensive
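
For a rough sense of scale (the numbers below come from the standard AlexNet and VGG architectures, not from this post), the first fully connected layer alone dwarfs the convolutional layers:

# Weight count of the first fully connected layer in two classic nets (biases omitted)
alexnet_fc1 = 256 * 6 * 6 * 4096   # AlexNet: 256x6x6 feature map -> 4096 units
vgg_fc1 = 512 * 7 * 7 * 4096       # VGG: 512x7x7 feature map -> 4096 units
print(f'{alexnet_fc1:,}')          # 37,748,736
print(f'{vgg_fc1:,}')              # 102,760,448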

NiN's Strategy

  • Use 1x1 convolutions to add local nonlinearity across channel activations (a small equivalence check follows this list)
  • Use average pooling to fuse all positions of the final representation layer; for the average pooling to be effective, extra nonlinearity has to be added first
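
To make the first point concrete, here is a minimal check (my own sketch, not code from the post) that a 1x1 convolution is exactly a fully connected layer applied independently at every spatial position, mixing only the channels:

import torch
from torch import nn

x = torch.randn(2, 8, 5, 5)                    # (N, C, H, W)
conv1x1 = nn.Conv2d(8, 16, kernel_size=1)
fc = nn.Linear(8, 16)
with torch.no_grad():                          # copy the conv weights into the FC layer
    fc.weight.copy_(conv1x1.weight.view(16, 8))
    fc.bias.copy_(conv1x1.bias)

out_conv = conv1x1(x)                                    # (2, 16, 5, 5)
out_fc = fc(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)   # FC applied per position
print(torch.allclose(out_conv, out_fc, atol=1e-6))       # True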

The NiN Block

A NiN block differs from a VGG block in two ways: first, the convolution at the start of the block is followed by 1x1 convolutions rather than more 3x3 convolutions; second, there are no fully connected layers at the end of the network.
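
The second difference deserves a small illustration (a sketch under my own assumptions, not code from the post): the final NiN block emits one channel per class, and global average pooling collapses each channel map into a single logit, so no fully connected classifier is needed:

import torch
from torch import nn

features = torch.randn(1, 10, 5, 5)  # final block output: 10 class channels, 5x5 spatial
head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),  # average each 5x5 map down to 1x1
                     nn.Flatten())                  # -> one logit per class
print(head(features).shape)          # torch.Size([1, 10])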

Code

import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
# Load the data
train_dataset = datasets.FashionMNIST(root="../../../datasets/", transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(224)]), train=True, download=False)
test_dataset = datasets.FashionMNIST(root="../../../datasets/", transform=transforms.Compose([transforms.ToTensor(), transforms.Resize(224)]), train=False, download=False)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Define the NiN network architecture
class NiN(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = self.nin_block(1, 96, 11, 4, 0)   # 224x224 input -> 54x54 feature map
        self.block2 = self.nin_block(96, 256, 5, 1, 1)
        self.block3 = self.nin_block(256, 384, 3, 1, 1)
        self.block4 = self.nin_block(384, 10, 3, 1, 1)  # 10 output channels, one per class
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))     # global average pooling to 1x1
        self.dropout = nn.Dropout(0.5)
        self.flatten = nn.Flatten()

    def nin_block(self, in_channels, out_channels, kernel_size, strides, padding):
        # One full-size convolution followed by two 1x1 convolutions, each with a ReLU
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding), nn.ReLU(),
                             nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
                             nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

    def forward(self, x):
        h1 = self.block1(x)
        h2 = self.maxpool(h1)
        h3 = self.block2(h2)
        h4 = self.maxpool(h3)
        h5 = self.block3(h4)
        h6 = self.maxpool(h5)
        h6 = self.dropout(h6)
        h7 = self.block4(h6)
        h8 = self.avgpool(h7)
        return self.flatten(h8)

# Uncomment to inspect the model structure:
# net = NiN()
# summary(net, (1, 1, 224, 224))


device = "cuda:0" if torch.cuda.is_available() else "cpu"
nin = NiN().to(device)
# Hyperparameters
epochs = 10
lr = 1e-3
# Optimizer
optimizer = torch.optim.Adam(nin.parameters(), lr=lr)
# Loss function
loss_fn = nn.CrossEntropyLoss()
# Training loop
for epoch in range(epochs):
    nin.train()  # re-enable dropout for training
    train_loss_epoch = []
    for train_data, labels in tqdm(train_dataloader):
        train_data = train_data.to(device)
        labels = labels.to(device)
        y_hat = nin(train_data)
        loss = loss_fn(y_hat, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss_epoch.append(loss.item())
    
    print(f'epoch:{epoch}, train_loss:{sum(train_loss_epoch) / len(train_loss_epoch)}')
    nin.eval()  # disable dropout for evaluation
    with torch.no_grad():
        test_loss_epoch = []
        right = 0
        for test_data, labels in tqdm(test_dataloader):
            test_data = test_data.to(device)
            labels = labels.to(device)
            y_hat = nin(test_data)
            loss = loss_fn(y_hat, labels)
            test_loss_epoch.append(loss.item())
            right += (torch.argmax(y_hat, 1) == labels).sum().item()
        acc = right / len(test_dataset)
        print(f'test_loss:{sum(test_loss_epoch) / len(test_loss_epoch)}, acc:{acc}')

Training Results

The hyperparameters could still be tuned further.
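
For example (an illustrative sketch, not something tested in this post), a learning-rate schedule could be layered on top of the Adam optimizer used above:

import torch

params = [torch.zeros(1, requires_grad=True)]  # stand-in for nin.parameters()
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(10):
    optimizer.step()        # stands in for one full training epoch
    scheduler.step()        # decay the lr by 10x every 5 epochs
    print(epoch, scheduler.get_last_lr())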
