Pytorch应用训练好的模型

本文详细介绍了PyTorch中如何保存和加载训练好的模型,包括保存模型结构和参数、仅保存参数的方式,以及针对分类问题的正确率计算。此外,还提供了CPU和GPU训练模型的完整代码,并分享了GPU训练过程中的细节优化和模型验证的方法。
摘要由CSDN通过智能技术生成

1.保存训练好的模型:torch.save方法

保存训练好的模型有两种方式,第一种保存模型结构且保存模型参数,第一种方式存在一种陷阱,也就是每次加载模型都得把类定义,或者访问类所在的包。保存方式为:

torch.save(模型名, 以pth为后缀的文件)

第二种保存方式只保存模型参数,不保存模型结构,这样可以面对较大的网络模型,可以节省空间,是官方推荐的保存方式,具体为:

torch.save(模型名.state_dict(), 以pth为后缀的文件)

第一个参数,模型名.state_dict()意为只取模型的参数,且以字典方式存储;第二个参数存储模型的地址,一般都用以pth结尾的文件。
代码如下:

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 陷阱
class test_Model(nn.Module):
    def __init__(self):
        super(test_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

test_model = test_Model()
torch.save(test_model, "test_model_method1.pth")

2.加载之前保存的模型:torch.load方法

对应两种不同的保存模型的方式,也相应地有两种加载模型的方式,第一种方式为:

读取出的模型名 = torch.load(之前保存的模型文件名)

第二种方式为:

读取出的模型名 = torchvision.models.vgg16(pretrained=False)
读取出的模型名.load_state_dict(torch.load(之前保存的模型文件名))

因为第二种方法没有保存模型结构,所以我们要先设计一个模型结构,本例中用的是直接下载VGG16模型结构;然后第二条语句用于将保存的参数值传入到模型结构中。代码如下:

import torch
from model_save import *
import torchvision
from torch import nn

# 方式1-》保存方式1,加载模型
model = torch.load("vgg16_method1.pth")
# print(model)

# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
# 只读取参数
# model = torch.load("vgg16_method2.pth")
# print(vgg16)

# 方式1,陷阱
# class test_Model(nn.Module):
#     def __init__(self):
#         super(test_Model, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         return x
model = torch.load('test_model_method1.pth')
print(model)

3.对于分类问题的补充

对于图像分类、目标检测或是语义分割等各种问题而言,单靠loss曲线还无法完全表现出模型的优劣,而在图像分类问题中,我们常用acc,即在测试集上的正确率来表示模型的训练情况及优劣。

在分类问题中,例如CIFAR10中,某张图片进入模型后的输出结果是形如([0.2, 0.1, 0.4, 0.6, 0.4, 0.4, 0.7, 0.2, 0.5, 0.1]),而targets是类似(5)这样的某个整型数字,表明该图片属于第5类。那么如和将输出结果与targets之间进行计算,产生正确率呢,我们采用argmax函数,该函数会找到数组中最大值并输出最大值的序号,那么如此一来,将输出结果的最大值序号和targets相比较,如果一致则说明该图像识别正确。代码如下:

import torch

output = torch.tensor([[0.2, 0.3],
                      [0.1, 0.5]])
# 纵向比较
print(output.argmax(0))
# 横向比较
print(output.argmax(1))

pred = output.argmax(1)
targets = torch.tensor([0, 1])
print(pred == targets)
# 输出正确的个数
print((pred == targets).sum())

注:因在验证过程中,因为不是一张图片,而是batch_size张图片验证,那么output应是二维数组,第二维的数目是batchsize个,此例中设为2个,相应的targets应是一个一维数组,数组中的每个数表示每张图对应的正确类别是哪类。另外,argmax(0)表示传入的数据纵向进行比较,在本例中是0.2和0.1比,0.3和0.5比;而argmax(1)表示传入的数据横向进行比较,在本例中是0.2和0.3比,0.1和0.5比,毫无疑问,运用在计算准确率中,我们应该使用的是argmax(1),那么记录所有output和targets相吻合的项的和除以总项数即可获得该batch_size的准确率。

4.CPU训练完整代码

CPU训练就是之前介绍的搭建模型和应用模型相结合,内容都是之前所讲,完整代码如下:

import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
# 准备数据集
from torch import nn
from torch.utils.data import DataLoader

train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)

# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 如果train_data_size=10, 训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))


# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# 创建网络模型
tudui = Tudui()

# 损失函数
loss_fn = nn.CrossEntropyLoss()

# 优化器
# learning_rate = 0.01
# 1e-2=1 x (10)^(-2) = 1 /100 = 0.01
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 
  • 44
    点赞
  • 341
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值