import torch
from torch import nn, optim


class MyNet(nn.Module):
    """Small three-layer MLP used to demonstrate model saving and loading.

    Maps ``input_size`` features to ``output_size`` outputs through hidden
    layers of width 128 and 64, each followed by a ReLU.
    """

    def __init__(self, input_size, output_size):
        super().__init__()  # call the parent (nn.Module) constructor
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw output of the last linear layer (no final activation)."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


def test01():
    """Save the whole model object (pickled class reference + weights)."""
    model = MyNet(10, 1)
    torch.save(model, 'MyNet_model.pth')


def test02():
    """Load the whole model object written by ``test01`` and print it.

    A full-model checkpoint is a pickle that references the ``MyNet`` class,
    so it can only be restored with ``weights_only=False``. Since PyTorch 2.6
    the default is ``weights_only=True``, which makes a bare
    ``torch.load(path)`` fail on this file.
    SECURITY: unrestricted unpickling can execute arbitrary code; only load
    files you created yourself.
    """
    model = torch.load('MyNet_model.pth', weights_only=False)
    print(model)


def test03():
    """Save only the model parameters (state_dict) and print them."""
    model = MyNet(10, 1)
    torch.save(model.state_dict(), 'MyNet_model_state_dict.pth')
    print(model.state_dict())


def test04():
    """Load the state_dict written by ``test03`` into a fresh ``MyNet``."""
    # A state_dict holds only tensors, so the restricted loader is sufficient.
    state_dict = torch.load('MyNet_model_state_dict.pth', weights_only=True)
    print(state_dict)
    model = MyNet(10, 1)
    model.load_state_dict(state_dict)
    print(model)


def test05():
    """Save model parameters and optimizer state together in one dict."""
    model = MyNet(20, 1)
    opt = optim.SGD(model.parameters(), lr=0.1)
    dic = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": opt.state_dict(),
    }
    torch.save(dic, 'MyNet_model_dict.pth')


def test06():
    """Restore the model and optimizer state saved by ``test05``."""
    # The checkpoint holds tensors and plain containers, so the restricted
    # loader is sufficient.
    dic = torch.load('MyNet_model_dict.pth', weights_only=True)
    # The architecture and optimizer type must match what test05 saved.
    model = MyNet(20, 1)
    opt = optim.SGD(model.parameters(), lr=0.1)
    # Load model parameters, then optimizer state.
    model.load_state_dict(dic['model_state_dict'])
    opt.load_state_dict(dic['optimizer_state_dict'])
    print(model)
    print(opt)


if __name__ == '__main__':
    # Other demos: run test01() then test02() for whole-model save/load,
    # or test03() then test04() for state_dict-only save/load.
    test05()
    test06()
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image


class Mnist_net(nn.Module):
    """Fully connected MNIST classifier producing 10 class logits."""

    def __init__(self):
        super().__init__()
        # 1 * 28 * 28: every pixel of a single-channel 28x28 image is one input feature.
        self.fc1 = nn.Linear(1 * 28 * 28, 256)
        self.fc2 = nn.Linear(256, 32)
        self.fc3 = nn.Linear(32, 10)
        self.relu = nn.ReLU()  # ReLU activation shared by the hidden layers

    def forward(self, x):
        """Flatten ``x`` into rows of 784 pixels and return (rows, 10) logits."""
        # Same shape semantics as view(-1, 784), but reshape also accepts
        # non-contiguous tensors instead of raising.
        x = x.reshape(-1, 1 * 28 * 28)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


def build_data(root='./data', train_batch_size=64, val_batch_size=128):
    """Download MNIST into ``root`` and return ``(train_loader, val_loader)``.

    Images are converted with ``ToTensor`` only (no normalisation). Only the
    training loader shuffles: evaluation order does not affect accuracy, and
    a fixed order keeps validation runs reproducible.
    """
    transform = transforms.ToTensor()
    # Training split
    train_dataset = datasets.MNIST(root, train=True, transform=transform, download=True)
    # Validation (test) split
    val_dataset = datasets.MNIST(root, train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader


def train(model, train_loader, epochs, lr=0.1):
    """Train ``model`` with SGD and cross-entropy loss for ``epochs`` epochs.

    After each epoch, prints the mean batch loss and the training accuracy.
    The accuracy uses predictions made before each optimizer step.
    """
    criterion = nn.CrossEntropyLoss()
    opt = optim.SGD(model.parameters(), lr=lr)
    model.train()
    for epoch in range(epochs):
        loss_sum = 0.0
        acc_sum = 0
        for data, label in train_loader:
            y_pred = model(data)
            loss = criterion(y_pred, label)
            opt.zero_grad()
            loss.backward()
            opt.step()

            loss_sum += loss.item()
            pred_labels = y_pred.argmax(dim=1)  # predicted class per sample
            acc_sum += (pred_labels == label).sum().item()
        print(f'epoch:{epoch}, loss:{loss_sum / len(train_loader)}, acc:{acc_sum / len(train_loader.dataset)}')


def val(model, val_loader):
    """Evaluate ``model`` on ``val_loader``; print and return the accuracy."""
    model.eval()
    acc_sum = 0
    with torch.no_grad():
        for data, labels in val_loader:
            pred_labels = model(data).argmax(dim=1)
            acc_sum += (pred_labels == labels).sum().item()
    acc = acc_sum / len(val_loader.dataset)
    print(f'acc:{acc}')
    return acc


def save_model(model, path):
    """Save only the parameters (state_dict) of ``model`` to ``path``."""
    torch.save(model.state_dict(), path)


def load_model(path, map_location=None):
    """Load and return a state_dict previously written by ``save_model``.

    ``weights_only=True`` limits unpickling to tensors and plain containers.
    That is all a state_dict needs, and it prevents the file from executing
    arbitrary code.
    """
    return torch.load(path, map_location=map_location, weights_only=True)


def test(model, filepath, modelpath):
    """Load weights from ``modelpath`` into ``model`` and classify one image.

    The image at ``filepath`` is converted to grayscale, resized to 28x28,
    and turned into a tensor with ``ToTensor``, matching the training
    transform. Prints and returns the predicted class index.
    """
    # Read the image as grayscale; the context manager closes the file handle.
    with Image.open(filepath) as raw:
        img = raw.convert('L')
    # NOTE(review): MNIST digits are light strokes on a dark background. An
    # input with dark-on-light strokes may need inverting; confirm.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    t_img = transform(img).unsqueeze(0)  # add batch dimension -> (1, 1, 28, 28)

    # Load the saved parameters into the model.
    model.load_state_dict(load_model(modelpath))
    model.eval()
    with torch.no_grad():  # inference only: skip building the autograd graph
        pred = model(t_img)
    label = pred.argmax(dim=1).item()
    print(f"预测类别: {label}")
    return label


if __name__ == '__main__':
    train_loader, val_loader = build_data()
    model = Mnist_net()
    epochs = 10  # number of training epochs
    train(model, train_loader, epochs)
    val(model, val_loader)
    # Save the trained parameters.
    modelpath = 'mnist.pt'
    save_model(model, modelpath)
    filepath = 'img/3.png'
    test(model, filepath, modelpath)