模型保存和读取、手写数字训练和测试代码示例

import torch
from torch import nn,optim

class MyNet(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()  # 调用父类的构造函数
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)

        return x

# 保存模型
def test01():
    model = MyNet(10, 1)
    torch.save(model, 'MyNet_model.pth')

# 加载模型
def test02():
    model = torch.load('MyNet_model.pth')
    print(model)

#保存模型参数
def test03():
    model=MyNet(10, 1)
    torch.save(model.state_dict(),'MyNet_model_state_dict.pth')
    print(model.state_dict())

#加载模型参数
def test04():
    state_dict = torch.load('MyNet_model_state_dict.pth')
    print(state_dict)

    model = MyNet(10, 1)
    model.load_state_dict(state_dict)

    print(model)

#保存模型参数和优化器参数
def test05():
    model = MyNet(20, 1)
    opt=optim.SGD(model.parameters(), lr=0.1)
    dic={
        "model_state_dict":model.state_dict(),
        "optimizer_state_dict":opt.state_dict(),
    }
    torch.save(dic,'MyNet_model_dict.pth')


def test06():
    dic=torch.load('MyNet_model_dict.pth')
    model = MyNet(20, 1)
    opt=optim.SGD(model.parameters(), lr=0.1)

    #加载模型参数
    model.load_state_dict(dic['model_state_dict'])
    opt.load_state_dict(dic['optimizer_state_dict'])
    print(model)
    print(opt)


if __name__ == '__main__':
    # test01()
    # test02()  # 调用加载模型的函数进行测试
    # test03()
    # test04()
    test05()
    test06()


import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image


class Mnist_net(nn.Module):
    def __init__(self):
        super(Mnist_net, self).__init__()
        # 1*28*28:表示图片所有通道的像素点特征
        self.fc1 = nn.Linear(1 * 28 * 28, 256)
        self.fc2 = nn.Linear(256, 32)
        self.fc3 = nn.Linear(32, 10)
        self.relu = nn.ReLU()  # 定义 relu 激活函数

    def forward(self, x):
        # 将图片数据展平为一维数组
        x = x.view(-1, 1 * 28 * 28)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 数据准备
def build_data():
    transform = transforms.ToTensor()
    # 生成训练数据集
    train_dataset = datasets.MNIST('./data', train=True, transform=transform, download=True)

    # 生成验证数据集
    val_dataset = datasets.MNIST('./data', train=False, transform=transform, download=True)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=True)
    return train_loader, val_loader


# 训练
def train(model, train_loader, epochs):
    # 定义损失函数
    criterion = nn.CrossEntropyLoss()
    # 定义优化器
    opt = optim.SGD(model.parameters(), lr=0.1)
    # 训练
    model.train()
    for epoch in range(epochs):
        # 损失
        loss_sum = 0
        # 准确度
        acc_sum = 0
        for data, label in train_loader:
            # 获取预测值
            y_pred = model(data)
            loss = criterion(y_pred, label)

            loss.backward()
            opt.step()
            opt.zero_grad()
            loss_sum += loss.item()
            _, correct = torch.max(y_pred, 1)
            acc_sum += (correct == label).sum().item()
        print(f'epoch:{epoch}, loss:{loss_sum / len(train_loader)}, acc:{acc_sum / len(train_loader.dataset)}')


def val(model, val_loader):
    model.eval()
    acc_sum = 0
    with torch.no_grad():
        for data, labels in val_loader:
            y_pred = model(data)
            _, correct = torch.max(y_pred, 1)
            acc_sum += (correct == labels).sum().item()
    print(f'acc:{acc_sum / len(val_loader.dataset)}')


def save_model(model, path):
    torch.save(model.state_dict(), path)


def load_model(path):
    return torch.load(path)  # 返回加载的模型状态字典


def test(model, filepath, modelpath):
    # 读取图片,转换为灰度图
    img = Image.open(filepath).convert('L')
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor()
    ])
    t_img = transform(img)
    t_img = t_img.unsqueeze(0)
    # 加载模型参数
    model_state_dict = load_model(modelpath)
    model.load_state_dict(model_state_dict)
    pred = model(t_img)
    _, correct = torch.max(pred, 1)
    print(f"预测类别: {correct.item()}")


if __name__ == '__main__':
    train_loader, val_loader = build_data()
    model = Mnist_net()
    epochs = 10  # 定义训练轮数
    train(model, train_loader, epochs)
    val(model, val_loader)

    # 模型保存
    modelpath = 'mnist.pt'
    save_model(model, modelpath)

    filepath = 'img/3.png'
    test(model, filepath, modelpath)  # 传递 model 参数
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值