pytorch加载已有模型的方式以及使用+加载预训练模型进行部分参数复制

15 篇文章 3 订阅

一、加载已有模型直接使用

temp=torch.load("E:\\study-proj\\图像分类:从零到亿\\5.使用更多模型\\model_resnet101.pth") #加载模型,如果只有数值就只会加载模型数据,如果有字典,则会加载模型数据和字典数据
model.load_state_dict(temp)  #返回是否成功

由于模型保存的时候有保存数据和保存数据和字典的方式,所以加载的时候就有两种,利用torch.load,可将不管是数据还是数据和字典都可以加载上,但是如果只是数据,就需要将数据加载到对应的模型上,所以就有如下两种方式:

加载字典(模型对应的字典,其实就是模型变量)以及数据

model=torch.load("pth文件路径")

加载数据

temp=torch.load("E:\\study-proj\\图像分类:从零到亿\\5.使用更多模型\\model_resnet101.pth") #加载模型,如果只有数值就只会加载模型数据,如果有字典,则会加载模型数据和字典数据
model.load_state_dict(temp)  #返回是否成功

完整示例:
代码位置:https://gitee.com/sxh_and_ll/AI-CV/blob/master/proj/%E4%BD%BF%E7%94%A8pytorch%E8%87%AA%E5%B8%A6%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83%E4%BB%A5%E5%8F%8A%E6%B5%8B%E8%AF%95/load_model_test.py

'''
    加载模型,进行测试
'''
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
from utils import LoadData

from torchvision.models import alexnet  #最简单的模型
from torchvision.models import vgg11, vgg13, vgg16, vgg19   # VGG系列
from torchvision.models import resnet18, resnet34,resnet50, resnet101, resnet152    # ResNet系列
from torchvision.models import inception_v3     # Inception 系列


def test(dataloader, model):
    size = len(dataloader.dataset)
    # 将模型转为验证模式
    model.eval()
    # 初始化test_loss 和 correct, 用来统计每次的误差
    test_loss, correct = 0, 0
    # 测试时模型参数不用更新,所以no_gard()
    # 非训练, 推理期用到
    with torch.no_grad():
        # 加载数据加载器,得到里面的X(图片数据)和y(真实标签)
        for X, y in dataloader:
            # 将数据转到GPU
            X, y = X.to(device), y.to(device)
            # 将图片传入到模型当中就,得到预测的值pred
            pred = model(X)
            # 计算预测值pred和真实值y的差距
            test_loss += loss_fn(pred, y).item()
            # 统计预测正确的个数
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()#返回相应维度的最大值的索引
    test_loss /= size
    correct /= size
    print(f"correct = {correct}, Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")




if __name__=='__main__':
    batch_size = 8

    # # 给训练集和测试集分别创建一个数据集加载器
    train_data = LoadData("train.txt", True)
    valid_data = LoadData("test.txt", False)

    test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)

    # 如果显卡可用,则用显卡进行训练
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device='cpu'
    print(f"Using {device} device")


    # 加载模型
    temp=torch.load("E:\\study-proj\\图像分类:从零到亿\\5.使用更多模型\\model_resnet101.pth") #加载模型,如果只有数值就只会加载模型数据,如果有字典,则会加载模型数据和字典数据
    model.load_state_dict(temp)
    print(model)
    
    # 定义损失函数,计算相差多少,交叉熵,
    loss_fn = nn.CrossEntropyLoss()

    # 定义优化器,用来训练时候优化模型参数,随机梯度下降法
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # 初始学习率

    epochs = 1
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        # print(f"train time: {(time_end-time_start)}")
        time_start=time.time()
        test(test_dataloader, model)
        time_end=time.time()
    print("Done!")

二、加载预训练模型进行参数复制

相关定义

state_dict()获得模型的网络以及参数字典

实现流程

  1. 获得相关模型以及预训练的模型参数字典;
  2. 循环遍历网络层,一致便复制相关参数,否则就不操作;
#加载预训练模型
resnet = models.resnet50(pretrained=True)
new_state_dict = resnet.state_dict()
dd = net.state_dict()
for k in new_state_dict.keys():
    print(k)
    if k in dd.keys() and not k.startswith('fc'):
        print('yes')
        dd[k] = new_state_dict[k]
net.load_state_dict(dd)
  • 1
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值