# 在MNIST数据集上训练一个手写数字识别模型

1.1 数据下载

import torchvision as tv

testing_sets = tv.datasets.MNIST(root="", train=False, download=True)

1.2 数据增强

import torchvision.transforms as tf

transform = tf.Compose([
tf.ToTensor(),
tf.Normalize((0.1307), (0.3081)),
tf.RandomAffine(translate=(0.2, 0.2), degrees=0),
tf.RandomRotation((-20, 20))
])

1.3 数据加载

from torch.utils.data import DataLoader
import torchvision as tv

# 生成测试集和训练集的迭代器
tv.datasets.MNIST("data", transform=transform),
batch_size=batch_size,
shuffle=True,
num_workers=10)  # num_workers 根据机子实际情况设置，一般<=CPU数

tv.datasets.MNIST("data", transform=tf.ToTensor(), train=False),
batch_size=batch_size,
shuffle=True,
num_workers=10)

2.1 CNN识别模型

该模型来自这里，该论文也提供了完整的训练代码在这里。左图的模型由10层卷积层和一层全连接层组成。该模型没有采用最大值池化或均值池化，每层卷积层由一次3*3卷积，一个批次归一化单元和一个ReLU激活函数组成。除第一次卷积外，每次卷积的通道数加16。因为在步长为1，padding为0时，经过3*3卷积的特征图高宽各减2，所以最后一层卷积层输出的特征图大小为 [m, 176,8,8]. 展开后 [m, 176*8*8]。最后采用的损失函数是交叉熵损失。

class M3(nn.Module):
def __init__(self) -> None:
super(M3, self).__init__()

self.conv_list = nn.ModuleList([self.bn2d(1, 32)])

for i in range(32, 161, 16):
self.conv_list.append(self.bn2d(i, i+16))

self.FC = nn.Sequential(
nn.Flatten(),
nn.Linear(11264, 10), nn.BatchNorm1d(10),
nn.LogSoftmax(dim=1)
)

def forward(self, x):
for conv in self.conv_list:
x = conv(x)
return self.FC(x)

def bn2d(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)

3.1 其他准备

model = M3() # 实例化模型
# 使用Adam优化器，初始学习率1e-3，其他参数全部默认就行

# 应用学习率衰减，采用指数衰减，衰减率为0.9
scheduler = opt.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)
# 损失函数, NLLLoss与LogSoftmax配合 == CrossEntropyLoss
criterion = nn.NLLLoss()

3.2 训练代码

from model import M3
import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tf
import torch.optim as opt
from typing import *
import numpy as np
import matplotlib.pyplot as plt
import copy

# 是否在GPU上训练 ----------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available()
else "cpu")

# 超参 --------------------------------------------------------
batch_size: int = 120
learning_rate: float = 1e-3
gamma: float = 0.98
epochs: int = 120

# 数据增强操作 ------------------------------------------------
transform = tf.Compose([
tf.ToTensor(),
tf.Normalize((0.1307), (0.3081)),
tf.RandomAffine(translate=(0.2, 0.2), degrees=0),
tf.RandomRotation((-20, 20))
])

# 生成测试集和训练集的迭代器 -----------------------------------
train_sets = tv.datasets.MNIST("data", transform=transform)
test_sets = tv.datasets.MNIST("data", transform=tf.Compose([tf.ToTensor(),
tf.Normalize((0.1307), (0.3081))]),
train=False)

train_sets,
batch_size=batch_size,
shuffle=True,
num_workers=10)  # num_workers 根据机子实际情况设置，一般<=CPU数

test_sets,
batch_size=batch_size,
shuffle=True,
num_workers=10)

data_size = {"train": len(train_sets), "test": len(test_sets)}

# 导入模型 ---------------------------------------------------
model = M3().to(device)

# 损失函数, 与LogSoftmax配合 == CrossEntropyLoss -------------
criterion = nn.NLLLoss().to(device=device)

# 学习率指数衰减, 每轮迭代学习率更新为原来的gamma倍；verbose为True时，打印学习率更新信息
scheduler = opt.lr_scheduler.ExponentialLR(
optimizer=optimizer, gamma=gamma, verbose=True)

# 训练 -------------------------------------------------------

def train_model(num_epochs: int = 1):

best_acc: int = 0  # 记录测试集最高的预测得分, 用于辅助保存模型参数
epochs_loss = np.array([])  # 记录损失变化，用于可视化模型
epochs_crr = np.array([])  # 记录准确率变化，用于可视化模型

for i in range(num_epochs):
# 打印迭代信息
print(f"epochs {i+1}/{num_epochs} :")

# 每次迭代再分训练和测试两个部分
for phase in ["train", "test"]:

if phase == "train":
model.train()
else:
model.eval()

running_loss: float = 0.0
running_crr: int = 0

# 如果是在GPU上训练，数据需进行转换
inputs, labels = inputs.to(device), labels.to(device)

# 测试模型时不需要计算梯度
outputs = model.forward(inputs)
loss = criterion(outputs, labels)

# 是否回溯更新梯度
if phase != "test":
loss.backward()
optimizer.step()

# 记录损失和准确预测数
_, preds = torch.max(outputs, 1)
running_crr += (preds == labels.data).sum()
running_loss += loss.item() * inputs.size(0)

# 打印结果和保存参数
acc = (running_crr/data_size[phase]).cpu().numpy()
if phase == "test" and acc > best_acc:
best_acc = acc
model_duplicate = copy.deepcopy(model)

avg_loss = running_loss/data_size[phase]
print(f"{phase} >>> acc:{acc:.4f},{running_crr}/{data_size[phase]}; loss:{avg_loss:.5f}")

# 保存损失值和准确率
epochs_crr = np.append(epochs_crr, acc)
epochs_loss = np.append(epochs_loss, avg_loss)

print(f"best test acc {best_acc:.4f}")
scheduler.step()  # 更新学习率

return model_duplicate, best_acc, epochs_crr, epochs_loss

def visualize_model(epochs_crr: np.array, epochs_loss: np.array):

# 分离训练和测试损失与准确率
epochs_crr = epochs_crr.reshape(-1, 2)
epochs_loss = epochs_loss.reshape(-1, 2)
train_crr, test_crr = epochs_crr[:, 0], epochs_crr[:, 1]
train_lss, test_lss = epochs_loss[:, 0], epochs_loss[:, 1]

# 绘制准确率变化图
plt.subplot(1,2,1)
plt.plot(np.arange(len(train_crr)), train_crr, "-g", label="train")
plt.plot(np.arange(len(test_crr)), test_crr, "-m", label="test")
plt.title("accuracy")
plt.xlabel("epochs")
plt.ylabel("acc")
plt.legend()
plt.grid()

# 绘制损失值变化图
plt.subplot(1,2,2)
plt.plot(np.arange(len(train_lss)), train_lss, "-g", label="train")
plt.plot(np.arange(len(test_lss)), test_lss, "-m", label="test")
plt.title("accuracy")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.legend()
plt.grid()

# 保存结果
plt.tight_layout()
plt.savefig(f"result.png")
plt.close()

if __name__ == "__main__":
trained_model, best_test_acc, crr, lss = train_model(epochs)
visualize_model(crr, lss)
torch.save(trained_model.state_dict(), f"model_params_{best_test_acc}.pth")



3.3 可视化结果

4.1 其他补充

torchsummary查看模型的详细细节

    import torchsummary
torchsummary.summary(model=M3(), input_size=(1, 28, 28))

----------------------------------------------------------------
Layer (type)               Output Shape         Param #
================================================================
Conv2d-1           [-1, 32, 26, 26]             320
BatchNorm2d-2           [-1, 32, 26, 26]              64
ReLU-3           [-1, 32, 26, 26]               0
Conv2d-4           [-1, 48, 24, 24]          13,872
BatchNorm2d-5           [-1, 48, 24, 24]              96
ReLU-6           [-1, 48, 24, 24]               0
Conv2d-7           [-1, 64, 22, 22]          27,712
BatchNorm2d-8           [-1, 64, 22, 22]             128
ReLU-9           [-1, 64, 22, 22]               0
Conv2d-10           [-1, 80, 20, 20]          46,160
BatchNorm2d-11           [-1, 80, 20, 20]             160
ReLU-12           [-1, 80, 20, 20]               0
Conv2d-13           [-1, 96, 18, 18]          69,216
BatchNorm2d-14           [-1, 96, 18, 18]             192
ReLU-15           [-1, 96, 18, 18]               0
Conv2d-16          [-1, 112, 16, 16]          96,880
BatchNorm2d-17          [-1, 112, 16, 16]             224
ReLU-18          [-1, 112, 16, 16]               0
Conv2d-19          [-1, 128, 14, 14]         129,152
BatchNorm2d-20          [-1, 128, 14, 14]             256
ReLU-21          [-1, 128, 14, 14]               0
Conv2d-22          [-1, 144, 12, 12]         166,032
BatchNorm2d-23          [-1, 144, 12, 12]             288
ReLU-24          [-1, 144, 12, 12]               0
Conv2d-25          [-1, 160, 10, 10]         207,520
BatchNorm2d-26          [-1, 160, 10, 10]             320
ReLU-27          [-1, 160, 10, 10]               0
Conv2d-28            [-1, 176, 8, 8]         253,616
BatchNorm2d-29            [-1, 176, 8, 8]             352
ReLU-30            [-1, 176, 8, 8]               0
Flatten-31                [-1, 11264]               0
Linear-32                   [-1, 10]         112,650
BatchNorm1d-33                   [-1, 10]              20
LogSoftmax-34                   [-1, 10]               0
================================================================
Total params: 1,125,230
Trainable params: 1,125,230
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 5.70
Params size (MB): 4.29
Estimated Total Size (MB): 9.99
----------------------------------------------------------------

• 0
点赞
• 2
收藏
觉得还不错? 一键收藏
• 打赏
• 2
评论
02-15 5080
03-28 2万+
04-14 567
07-16 2万+
11-07 8万+
10-23 832
06-01 1331

### “相关推荐”对你有帮助么？

• 非常没帮助
• 没帮助
• 一般
• 有帮助
• 非常有帮助

Satttt

¥1 ¥2 ¥4 ¥6 ¥10 ¥20

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、付费专栏及课程。