保存和加载模型需要注意:
pytorch保存模型后加载模型遇到的大坑_模型保存参数再加载测试,结果跟保存前差了很多,可能的原因是什么-CSDN博客
加载已有的pth模型后为什么会重新训练?
假如有两个文件
train.py
定义神经网络
class Network(nn.Module):
def __init__(self):
super().__init__()
......
#生成对象,开始训练
network = Network()
......
#保存参数
torch.save(network,'cnn.pth')
test.py
from train import Network
network=torch.load('./cnn.pth')
在test.py运行后会重新训练train.py的模型,为什么会这样?
from train import Network 看似只导入了网络类,其实会把整个文件都导入了进来,而train.py里面训练模型的代码是全局的变量或对象,导入进来会重新运行,所以这么解决
1. 把train.py里面训练模型的代码封装成函数或者是局部的,直接加
if __name__ == "__main__":
定义神经网络
class Network(nn.Module):
def __init__(self):
super().__init__()
......
if __name__ == "__main__":
#生成对象,开始训练
network = Network()
......
#保存参数
torch.save(network,'cnn.pth')
2.因为保存的模型文件.pth迁移到别的地方使用需要用到定义的网络模型类,在test.py中不导入模型类了,直接把定义的模型类复制过来。
定义神经网络
class Network(nn.Module):
def __init__(self):
super().__init__()
......
#加载模型使用
network=torch.load('./cnn.pth')