pytorch文档阅读(五)如何保存、加载网络模型

1.网络的保存

torch.save()有两种方法

1)仅保存网络参数

torch.save(net.state_dict(), 'net_params.pkl')

2)保存整个网络结构

 torch.save(net, 'net.pkl')

2.网络的加载

1)仅加载参数

model_object.load_state_dict(torch.load('net_params.pkl')) 

2)加载整个模型

model = torch.load('net.pkl') 

两种方法在载入模型时都需要有预设的网络结构,例如下边代码,否则会提示找不到相应的module

#加载整个网络

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x

net = torch.load("TestSave.pkl")#加载整个模型时直接用这句就可以实例化网络,并且把CUDA上运行这个属性也继承了过来
#只加载网络参数

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet,self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.pool1 = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.pool2 = nn.MaxPool2d(3, 2)
        self.fc1 = nn.Linear(1024, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x

net = AlexNet()#只加载网络参数的时候需要自行实例化网络
net.cuda()#并设置网络运行在cpu还是gpu上

net.load_state_dict(torch.load('net_params.pkl'))#再加载网络的参数

注意:

1.只加载网络参数的速度比加载整个网络快得多

2.pth、pkl格式效果相同,ckpt是tensorflow的格式

参考链接:

https://www.jb51.net/article/139102.htm

https://www.jianshu.com/p/0eda629e4007

  • 11
    点赞
  • 75
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在使用PyTorch框架进行DQN算法训练时,可以使用PyTorch提供的模型保存方法来保存DQN模型。具体步骤如下: 1. 首先,定义并构建DQN模型。根据引用提到的,PyTorch框架可以用于实现DQN算法,因此可以使用PyTorch提供的神经网络模块来构建Q网络。 2. 在训练过程中,可以选择在每个训练轮次或者指定步骤时,使用PyTorch提供的模型保存方法将当前的DQN模型保存到硬盘上的指定位置。 3. 保存模型时,可以指定保存的文件名和路径,以便在需要的时候可以方便地模型。可以使用PyTorch提供的`torch.save()`函数来保存模型。 4. 模型时,可以使用PyTorch提供的`torch.load()`函数来保存模型文件。 由于引用中提到了PyTorch框架和DQN算法的结合,可以推断出在使用PyTorch进行DQN算法训练时,可以通过PyTorch提供的模型保存功能来保存DQN模型。但具体的保存代码和细节需要参考PyTorch官方文档或者相关教程。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [强化学习算法Pytorch实现全家桶](https://download.csdn.net/download/weixin_44564247/19729484)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* *3* [【深度强化学习】(1) DQN 模型解析,附Pytorch完整代码](https://blog.csdn.net/dgvv4/article/details/129447669)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值