pytorch保存模型pth_详解Pytorch中的网络构造,模型save和load,.pth权重文件解析

本文详细介绍了PyTorch中模型的保存与加载,特别是.pth权重文件的处理。在进行微调时,若使用旧.pth文件可能会导致学习率不变。解决办法是仅保存model部分。PyTorch模型实例化后维护了8个字典,用于网络前向、反向和序列化。在实例化网络时,必须调用nn.Module的构造函数。模型的保存和加载涉及state_dict,通过torch.save和torch.load操作。load_state_dict函数用于加载模型参数,strict参数控制是否严格匹配。此外,还讨论了跨设备保存和加载、DataParallel模型的序列化等场景。
摘要由CSDN通过智能技术生成

pytorch最后的权重文件是.pth格式的。

经常遇到的问题:

进行finutune时,改配置文件中的学习率,发现程序跑起来后竟然保持了以前的学习率, 并没有使用新的学习率。

原因:

首先查看.pth文件中的内容,我们发现它其实是一个字典格式的文件:

其中保存了optimizer和scheduler,所以再次加载此文件时会使用之前的学习率。

我们只需要权重,也就是model部分,将其导出就可以了

import torch

original= torch.load('path/to/your/checkpoint.pth')new = {"model": original["model"]}

torch.save(new, 'path/to/new/checkpoint.pth')

背景

在PyTroch框架中,如果要自定义一个Net(网络,或者model,在本文中,model和Net拥有同样的意思),通常需要继承自nn.Module然后实现自己的layer。比如,在下面的示例中,gemfield(tiande亦有贡献)使用Pytorch实现了一个Net(可以看到其父类为nn.Module):

import torch

import torch.nnasnn

import torch.nn.functionalasFclassCivilNet(nn.Module):

def __init__(self):

super(CivilNet, self).__init__()

self.conv1= nn.Conv2d(3, 6, 5)

self.pool= nn.MaxPool2d(2, 2)

self.conv2= nn.Conv2d(6, 16, 5)

self.fc1= nn.Linear(16 * 5 * 5, 120)

self.fc2= nn.Linear(120, 84)

self.fc3= nn.Linear(84, 10)

self.gemfield= "gemfield.org"self.syszux= torch.zeros([1,1])

def forward(self, x):

x=self.pool(F.relu(self.conv1(x)))

x=self.pool(F.relu(self.conv2(x)))

x= x.view(-1, 16 * 5 * 5)

x=F.relu(self.fc1(x))

x=F.relu(self.fc2(x))

x=self.fc3(x)return x

这就带来了一系列的问题:

1,为什么要继承自nn.Module?

2,网络的各个layer或者module为什么要直接定义在构造函数中,而不能(比方说)放在构造函数中的一个list里?

3,forward函数什么时候会被调用?为什么要使用net(input)而不是net.forward(input)来做前向呢?

4,保存模型时,保存的究竟是什么?

5,重新载入一个pth模型时,究竟发生了什么?

你肯定要问了,为什么没说到反向?因为反向是optimizer和tensor的grad共同完成的,本文只讨论Net部分,这一系列文章的后续部分会讨论反向。

CivilNet的实例化

一个Net,也就是继承自nn.Module的类,当实例化后,本质上就是维护了以下8个字典(OrderedDict):

_parameters

_buffers

_backward_hooks

_forward_hooks

_forward_pre_hooks

_state_dict_hooks

_load_state_dict_pre_hooks

_modules

这8个字典用于网络的前向、反向、序列化、反序列化中。

因此,当实例化你定义的Net(nn.Module的子类)时,要确保父类的构造函数首先被调用,这样才能确保上述8个OrderedDict被create出来,否则,后续任何的初始化操作将抛出类似这样的异常:cannot assign module before Module.__init__() call。

对于前述的CivilNet而言,当CivilNet被实例化后,CivilNet本身维护了这8个OrderedDict,更重要的是,CivilNet中的conv1和conv2(类型为nn.modules.conv.Conv2d)、pool(类型为nn.modules.pooling.MaxPoo

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值