pytroch 如何保存与导入预训练模型

1 如何保存预训练模型

1.1 以字典形式 保存模型参数

torch.save可以保存我们的模型的部分参数 如下图。

//n_hidden,n_layers为超参.net.state_dict()为模型参数
class model(nn.Module):
       def __init__(self, **kwargs):
           def __init__(self, dataset, embedding):
           self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, dropout=drop_prob, batch_first=True)
           self.fc = nn.Linear(n_hidden, len(self.chars))
           ... ... 
       def forward():
           ... ...
model_name = 'rnn_x_epoch.net'
checkpoint = {'n_hidden': net.n_hidden,
              'n_layers': net.n_layers,
              'state_dict': net.state_dict()}
 with open(model_name, 'wb') as f:
    torch.save(checkpoint, f);

1.2 直接保存整个模型

torch.save也可以保存我们的整个模型 如下图。

torch.save(model, './path')

具体选择哪一种依照自己的需要,如果只取模型中的一部分,第一种感觉方便一些。如果希望以后直接加载现成的模型。第二种可能方便一些。

2 如何加载自己训练的预训练模型

2.1 加载以字典形式保存的模型参数

model_name = 'rnn_x_epoch.net'
model=torch.load(model_name)
print(type(model))
print('____')
for i in model:
    print(i)
 print('____')
   

输出结果如下

//n_hidden,n_layers为超参.net.state_dict()为模型参数
<class 'dict'>
____
n_hidden
n_layers
state_dict
____
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
lstm.weight_ih_l1
lstm.weight_hh_l1
lstm.bias_ih_l1
lstm.bias_hh_l1
fc.weight
fc.bias

可以看到 按照第一种方式,是以字典方式将文件存储进了文件,那么我们怎么将这个里面训练好的网络加载进新的模型呢?

2.1.1 创建新的模型对象
//按照需要 创建一个你希望的新模型
class Net_1(nn.Module):
       def __init__(self, **kwargs):
           def __init__(self, dataset, embedding):
           self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, dropout=drop_prob, batch_first=True)
           self.fc1 = nn.Linear(n_hidden, len(self.chars))
           ... ... 
       def forward():
           ... ... 
2.1.1 创建新的模型对象
// model1_dict是一个存着net_1所有参数的字典,里面的数据是随机初始化的
 //model是我们的预训练模型,model['state_dict']存储着我们需要的参数
net_1=Net_1(*kwargs,**kwargs)
model1_dict = net_1.state_dict()
new_state_dict = {k:v for k,v in model['state_dict'].items() if k in model1_dict}

为什么需要这个if k in model1_dict语句呢?因为我们有时候只需要部分加载 而不是一股脑全部放上去,那么这个语句是怎么实现这个功能的呢?
我们来看看model1_dict的结构

for i in model1_dict:
  print(i)
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
lstm.weight_ih_l1
lstm.weight_hh_l1
lstm.bias_ih_l1
lstm.bias_hh_l1
fc1.weight
fc1.bias

在经过new_state_dict = {k:v for k,v in model[‘state_dict’].items() if k in model1_dict}语句以后 新的模型继承了lstm层的参数,但是没有继承线性层的参数,因为原模型线性层的名字为fc ,而新模型的线性层的参数为fc1。

这就需要我们在创建新的模型对象的时候,将希望保存的层,与原层有相同的名字,而不希望的保存的层,有不同的名字。
现在new_state_dict 中包含我们需要更新的所有参数

model1_dict.update(new_state_dict)	#更新参数
net_1.load_state_dict(model1_dict) #加载参数
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值