Pytorch中实现只导入部分模型参数的方式

 

今天小编就为大家分享一篇Pytorch中实现只导入部分模型参数的方式,具有很好的参考价值,希望对大家有所帮助。一起跟随微点阅读小编过来看看吧

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

import torch as t

from torch.nn import Module

from torch import nn

from torch.nn import functional as F

class Net(Module):

  def __init__(self):

    super(Net,self).__init__()

    self.conv1 = nn.Conv2d(3,32,3,1)

    self.conv2 = nn.Conv2d(32,3,3,1)

    self.w = nn.Parameter(t.randn(3,10))

    for p in self.children():

      nn.init.xavier_normal_(p.weight.data)

      nn.init.constant_(p.bias.data, 0)

  def forward(self, x):

    out = self.conv1(x)

    out = self.conv2(x)

  

    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))

    out = F.linear(out,weight=self.w)

    return out

然后我们保存这个网络的初始值。

1

2

model = Net()

t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

import torch as t

from torch.nn import Module

from torch import nn

from torch.nn import functional as F

  

  

class Net(Module):

  def __init__(self):

    super(Net, self).__init__()

    self.conv1 = nn.Conv2d(3, 32, 3, 1)

    self.conv2 = nn.Conv2d(32, 3, 3, 1)

    self.conv3 = nn.Conv2d(3,64,3,1)

    self.conv4 = nn.Conv2d(64,32,3,1)

    for p in self.children():

      nn.init.xavier_normal_(p.weight.data)

      nn.init.constant_(p.bias.data, 0)

  

    self.w = nn.Parameter(t.randn(3, 10))

  def forward(self, x):

    out = self.conv1(x)

    out = self.conv2(x)

  

    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))

    out = F.linear(out, weight=self.w)

    return out

我们现在试着导入之前保存的模型参数。

1

2

3

4

5

6

7

8

path = 'xxx.pth'

model = Net()

model.load_state_dict(t.load(path))

  

'''

RuntimeError: Error(s) in loading state_dict for Net:

 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias".

'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

1

2

3

4

5

6

7

8

path = 'xxx.pth'

model = Net()

save_model = t.load(path)

model_dict = model.state_dict()

state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}

print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])

model_dict.update(state_dict)

model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

for k in state_dict.keys():

  print(k)

  

'''

w

conv1.weight

conv1.bias

conv2.weight

conv2.bias

'''

for k in model_dict.keys():

  print(k)

  

'''

w

conv1.weight

conv1.bias

conv2.weight

conv2.bias

conv3.weight

conv3.bias

conv4.weight

conv4.bias

'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考

来源:微点阅读   https://www.weidianyuedu.com

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值