# 今天实现了一下 MNIST 手写体数据集的训练，代码中并没有使用测试集。
"""
@Title: 训练手写体0-9数字
@Time: 2023/11/26 14:01
@Author: Michael
"""
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
# 搭建神经网络
class Model(nn.Module):
    """Small convolutional classifier for 1x28x28 digit images.

    Layout: two conv -> ReLU -> max-pool stages (6, then 16 feature maps),
    then three fully connected layers (400 -> 120 -> 84 -> 10). The forward
    pass returns one raw score per class (10 outputs); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        feature_extractor = [
            # 1x28x28 -> 6x28x28 (padding=2 keeps the 28x28 size with a 5x5 kernel)
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
            nn.ReLU(),
            # 6x28x28 -> 6x14x14 (stride defaults to kernel_size)
            nn.MaxPool2d(kernel_size=2),
            # 6x14x14 -> 16x10x10
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            # 16x10x10 -> 16x5x5
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        classifier = [
            nn.Flatten(),  # 16x5x5 -> 400
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(in_features=84, out_features=10),
        ]
        # The attribute name `model1` and the flat layer order are part of the
        # saved module's structure, so they are kept exactly as before.
        self.model1 = nn.Sequential(*feature_extractor, *classifier)

    def forward(self, x):
        """Return the (N, 10) class scores for an (N, 1, 28, 28) batch."""
        return self.model1(x)
# 训练
def train(*, lr=0.01, batch_size=100, epoch=50, data_root="mnist_dataset",
          log_dir="logs", save_path="model.pkl"):
    """Train `Model` on the MNIST training split and save the whole module.

    Only the training split is used (no test/validation evaluation). After
    each epoch the summed training loss is printed and written to TensorBoard,
    and the full model object is saved with `torch.save`.

    Args:
        lr: SGD learning rate.
        batch_size: Number of images per mini-batch.
        epoch: Number of passes over the training set.
        data_root: Directory holding (or receiving) the MNIST files.
        log_dir: TensorBoard log directory.
        save_path: File the trained model is written to.
    """
    # Pick the device once instead of re-checking CUDA for every batch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = Model().to(device)
    loss_fun = nn.CrossEntropyLoss()  # takes raw scores + integer class labels
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # download=True: without it torchvision raises when the files are not
    # already present under data_root, despite the intent to download them.
    train_set = torchvision.datasets.MNIST(root=data_root,
                                           train=True,
                                           download=True,
                                           transform=torchvision.transforms.ToTensor())
    # Shuffle so each epoch visits the samples in a different order.
    train_data_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

    writer = SummaryWriter(log_dir=log_dir)
    try:
        for i in range(epoch):
            model.train()
            loss_sum = 0.0
            for img, target in train_data_loader:
                img = img.to(device)
                target = target.to(device)
                optimizer.zero_grad()              # clear gradients from the last step
                outputs = model(img)               # forward pass
                loss = loss_fun(outputs, target)
                loss.backward()                    # backpropagation
                optimizer.step()                   # parameter update
                loss_sum += loss.item()
            # loss_sum is the sum of the per-batch mean losses for this epoch.
            print("训练{}次,train_loss:{}".format(i + 1, loss_sum))
            writer.add_scalar("train_loss", loss_sum, i + 1)
            # Overwrite the checkpoint every epoch so an interrupted run still
            # leaves the latest weights on disk; the final file is unchanged.
            torch.save(model, save_path)
    finally:
        # Flush and close the log even if training is interrupted.
        writer.close()
# 应用
def practice(img_path, model_path):
    """Classify a single digit image with a model saved via `torch.save(model, ...)`.

    The image is converted to grayscale ('L'), resized to 28x28 and turned
    into a tensor with `ToTensor`, the same transform `train` applies to its
    data.
    NOTE(review): no color inversion is applied. If your images use a
    different foreground/background polarity than the training data,
    predictions will suffer; confirm for your inputs.

    Args:
        img_path: Path to the input image file.
        model_path: Path to a file holding a whole pickled `nn.Module`.

    Returns:
        The predicted digit as an int (it is also printed).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the image; the context manager closes the file handle
    # (convert() returns a new, independent image).
    with Image.open(img_path) as raw:
        img = raw.convert('L')

    # Preprocess: resize to the network's 28x28 input and add batch/channel dims.
    pro_img = torchvision.transforms.Compose([torchvision.transforms.Resize((28, 28)),
                                              torchvision.transforms.ToTensor()])
    img = torch.reshape(pro_img(img), (1, 1, 28, 28)).to(device)

    # The file contains a whole pickled module, so weights_only must be False
    # (it defaults to True on recent torch). SECURITY: this unpickles arbitrary
    # objects; only load model files from trusted sources.
    # map_location lets a GPU-saved model load on a CPU-only machine and
    # puts the weights on the same device as the input.
    try:
        model = torch.load(model_path, map_location=device, weights_only=False)
    except TypeError:
        # Older torch versions do not accept the weights_only argument.
        model = torch.load(model_path, map_location=device)
    model = model.to(device)
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(img)
    pred = int(output.argmax(1).item())

    dict_target = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    print('识别类型为:{}'.format(dict_target[pred]))
    return dict_target[pred]
if __name__ == '__main__':
    # Script entry point: classify the sample image with the saved model.
    # Run train() beforehand to produce model.pkl.
    sample_image, saved_model = "1.png", "model.pkl"
    practice(sample_image, saved_model)