加载已有的pth模型后为什么会重新训练?

保存和加载模型需要注意:

pytorch保存模型后加载模型遇到的大坑_模型保存参数再加载测试,结果跟保存前差了很多,可能的原因是什么-CSDN博客

加载已有的pth模型后为什么会重新训练?

假如有两个文件

train.py

定义神经网络
class Network(nn.Module):
    def __init__(self):
        super().__init__()

        ......

#生成对象,开始训练
network = Network()

......

#保存参数
torch.save(network,'cnn.pth')

test.py

from train import Network



network=torch.load('./cnn.pth')

在test.py运行后会重新训练train.py的模型,为什么会这样?

from train import Network 看似只导入了网络类,其实会把整个文件都导入了进来,而train.py里面训练模型的代码是全局的变量或对象,导入进来会重新运行,所以这么解决

1. 把train.py里面训练模型的代码封装成函数或者是局部的,直接加

if __name__ == "__main__":
定义神经网络
class Network(nn.Module):
    def __init__(self):
        super().__init__()

        ......

if __name__ == "__main__":

    #生成对象,开始训练
    network = Network()

    ......

    #保存参数
    torch.save(network,'cnn.pth')

2.因为保存的模型文件.pth迁移到别的地方使用需要用到定义的网络模型类,在test.py中不导入模型类了,直接把定义的模型类复制过来。

定义神经网络
class Network(nn.Module):
    def __init__(self):
        super().__init__()

        ......

#加载模型使用
network=torch.load('./cnn.pth')

加载训练好的模型进行预测,你需要使用 PyTorch 提供的 `torch.load` 函数加载模型的参数和状态字典。下面是一个加载模型并进行预测的示例: ```python import torch from torchvision import models, transforms from PIL import Image # 定义图像变换 transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # 加载训练好的模型 model = models.resnet50(pretrained=True) model.eval() # 加载图像并进行预处理 image = Image.open('image.jpg') image = transform(image).unsqueeze(0) # 加载模型参数和状态字典 checkpoint = torch.load('model.pth') model.load_state_dict(checkpoint['state_dict']) # 进行预测 with torch.no_grad(): output = model(image) # 处理预测结果 _, predicted_idx = torch.max(output, 1) predicted_label = predicted_idx.item() print(f"Predicted label: {predicted_label}") ``` 在这个示例中,首先定义了图像的预处理变换,然后使用 `models.resnet50` 加载了一个预训练的 ResNet-50 模型,并将其设为评估模式。接下来,加载了要预测的图像,并进行了相同的预处理操作。然后,使用 `torch.load` 加载训练模型的参数和状态字典,并使用 `load_state_dict` 将参数加载模型中。最后,通过将图像传入模型进行预测,并处理预测结果。 请注意,这只是一个示例,你需要根据你的具体情况和模型进行相应的修改。确保模型的架构和预处理操作与训练时保持一致。 希望这个示例能帮到你!如有任何疑问,请随时提问。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

IT小艺

如果文章对你有用,请我喝咖啡吧

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值