pytorch学习(二):模型的保存与加载

我使用的是pytorch1.5 的版本
这里主要涉及到两个函数

官网的解释 https://pytorch.org/docs/1.5.0/nn.html
save保存模型
    torch.save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=False)
        obj : 你要保存的对象
        f : 保存的文件位置
        后面的参数不管,默认就行
    举例
        x = torch.tensor([0, 1, 2, 3, 4])
        torch.save(x, 'tensor.pt')
   load加载模型     
    torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
        f : 文件对应的位置
        map_location : 描述torch的device,加载的时候的容器保存位置
    举例:
        torch.load('tensors.pt', map_location=torch.device('cpu'))
        torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

保存的时候要
要用state_dict()这个函数保存模型,再save
如:
state={‘net’:model.state_dict(),‘optimizer’:optimiser.state_dict(),‘epoch’:epochs}
torch.save(state,‘D:\AllPro\PytorchPro\two.pth’)
加载的时候反之
torch.load(‘D:\AllPro\PytorchPro\two.pth’)

model.state_dict() 返回一个model的完整状态的字典
print(model.state_dict())
打印结果
OrderedDict([('Lone.weight', tensor([[1.1060],
        [1.8362],
        [1.8172],
        [1.3367],
        [1.3266]])), ('Lone.bias', tensor([ 0.0952, -1.8990, -0.1769, -0.6908, -0.9089])), ('Ltwo.weight', tensor([[-0.5234,  0.0282, -0.0736, -0.1586,  0.0362],
        [ 2.3027,  1.9011,  1.6350,  2.2779,  1.6443],
        [ 2.0606,  1.4825,  1.5095,  1.4385,  1.8779],
        [ 2.1344,  1.5597,  1.5797,  2.0561,  1.7084],
        [ 1.9351,  1.7267,  1.4738,  1.5796,  1.4055]])), ('Ltwo.bias', tensor([-0.0395, -0.5296, -0.2373, -0.5782, -1.2260])), ('Lthr.weight', tensor([[-0.2432,  1.7701,  1.8104,  1.8243,  1.6312]])), ('Lthr.bias', tensor([-0.6451]))])

print(model)    model直接打印就只有网络信息      
MyLinnerClass(
  (Lone): Linear(in_features=1, out_features=5, bias=True)
  (Ltwo): Linear(in_features=5, out_features=5, bias=True)
  (Lthr): Linear(in_features=5, out_features=1, bias=True)
)        

其他的细节可以查看原文档
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
这里对上次的回归函数做一下实验

import torch.nn as nn
import torch
import os
import time
import torch.nn.functional as Fun

class MyLinnerClass(nn.Module):
    def __init__(self,inputDim, outputDim):
        super(MyLinnerClass,self).__init__()
        self.Lone = torch.nn.Linear(inputDim,5*inputDim)
        self.Ltwo = torch.nn.Linear(5*inputDim,5*inputDim)
        self.Lthr = torch.nn.Linear(5*inputDim,outputDim)

    def forward(self,x):
        tempOne = self.Lone(x)
        tempOne = Fun.relu(tempOne)
        tempTwo = self.Ltwo(tempOne)
        tempTwo = Fun.relu(tempTwo)
        tempThr = self.Lthr(tempTwo)
        return tempThr

def dataSetCreate(PointNumer=500,div=0.2):
    x = torch.arange(0, PointNumer*div, div)
    print(x.shape)
    x = torch.unsqueeze(x,dim=1)
    print(x.shape)
    noise = (torch.randn(PointNumer)) % 15 / 10
    notCurY = 1.2 * x*x + noise         #在这里其实是不标准的y,故意产生一个有误差的标签
    lableY = 1.2 * x*x
    return x, notCurY, lableY


def main():
    x, notCurY, lableY = dataSetCreate(500,0.2)
    model=MyLinnerClass(1,1)
    epochs=2000
    learningRate = 0.002
    optimiser = torch.optim.Adam(model.parameters(),lr=learningRate)      #优化函数
    certerion = nn.MSELoss()            #loss值

    for oneEpoch in range(epochs):
        oneEpoch += 1
        optimiser.zero_grad()
        train_outY = model(x)
        #loss = certerion(train_outY,lableY)
        loss = certerion(train_outY,notCurY)    #当标签使用的第二个非准确的 loss会有波动
        loss.backward()
        optimiser.step()
        if(oneEpoch%50==0):
            print("loss is :",loss.item())
    #方法一:使用字典的方式保存关键数据然后加载
    state={'net':model.state_dict(),'optimizer':optimiser.state_dict(),'epoch':epochs}
    torch.save(state,'D:\\AllPro\\PytorchPro\\two.pth')
    print("train over")

    #方法二 :只保存模型
    #torch.save(model,'D:\\AllPro\\PytorchPro\\thr.pth')

def evalModel():
    evModel = MyLinnerClass(1,1)
    #方法一:使用字典的方式保存关键数据然后加载
    #使用 checkpoint 取出保存的字典里面的东西
    checkpoint = torch.load('D:\\AllPro\\PytorchPro\\two.pth')
    #针对字典的内容一一赋值到你需要的地方
    evModel.load_state_dict(checkpoint['net'])
    optimiser = torch.optim.Adam(evModel.parameters(), lr=0.001)
    #优化参数赋值
    optimiser.load_state_dict(checkpoint['optimizer'])
    #epoch的轮数赋值
    epoch=checkpoint['epoch']
    print(epoch)
    #eval的目的是固定参数,这样传入的数据推理不会影响下一次
    evModel.eval()

    ##方法二 直接加载模型 但是模型没有完整的训练数据
    #evModel = torch.load('D:\\AllPro\\PytorchPro\\thr.pth')
    #evModel.eval()

    x, notCurY, lableY = dataSetCreate(500,0.2)     #数据输入的shape还是要保持一致
    print("evl Val is",evModel(x))
    print("over")



# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    main()
    evalModel()

使用DataParallel 或者 DistDataParalle 加载的模型

使用DP和DDP加载的模型那么他在load前也是要DDP和DP定义过的
model = nn.DataParallel(net, device_ids=device_ids)

import cv2
import os
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from models import DnCNN
from utils import *
from torch.nn.parallel import DistributedDataParallel as DDP
import apex
from apex import amp 

def main():
    #模型定义和设备初始化
    #deviceType='npu:0'
    deviceType='cpu'
    net = DnCNN(channels=1, num_of_layers=17)
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids)
    pathPth="/home/wangyuanming/DnCNN-PyTorch/netG.pth"
    
    #加载模型
    checkpoint = torch.load(pathPth) 
    #下面使用 [net] 的前提是你在保存模型的时候是用的 state={'net':model.state_dict(),'optimizer':optimiser.state_dict(),'epoch':epochs} 的字典保存的
    model.load_state_dict(checkpoint['net']) 
    print("model get")
    
    #读出图片
    file = "/home/wangyuanming/DnCNN-PyTorch/data/Set12/01.png"
    Img = cv2.imread(file)
    Img = np.float32(Img[:,:,0])/255
    Img = np.expand_dims(Img, 0)
    Img = np.expand_dims(Img, 1)
    ISource = torch.Tensor(Img)
    noise = torch.FloatTensor(ISource.size()).normal_(mean=0, std=15/255.)
    print("data ready")
    ISource ,noise = ISource.to(deviceType),noise.to(deviceType)
    INoisy = ISource + noise
    model = model.to(deviceType)
    model.eval()
    print("model evl")
    with torch.no_grad(): # this can save much memory
        Out = torch.clamp(INoisy-model(INoisy), 0., 1.)
    psnr = batch_PSNR(Out,ISource,1.)
   
    print("psnr {:.3f}".format(psnr))
    
if __name__ == "__main__":
    main()
    

我们可以看看源代码我这里是如何保存的

state={'net':model.state_dict(),'optimizer':optimiser.state_dict(),'epoch':epochs} 

对应的若保存的时候不是字典
保存

 torch.save(model.state_dict(), os.path.join(opt.outf, 'netH.pth'))

加载

pathPth="/home/wangyuanming/DnCNN-PyTorch/logs/netH.pth"
#加载模型
checkpoint = torch.load(pathPth) 
model.load_state_dict(checkpoint) 
print("model get")
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值