pytorch 提取权重_PyTorch中如何加载子模块的权重

该博客探讨了如何在PyTorch中将预训练的子模型B的权重迁移到父模型A中。通过示例代码展示了如何处理模型状态字典,以确保子模型的权重正确地加载到父模型的相应部分。这涉及到修改原始模型字典,以匹配父模型的结构,并使用`load_state_dict`方法更新父模型的状态。
摘要由CSDN通过智能技术生成

假设我们在深度学习模型中有一个这样的需求:主要模型A中包含子模块B,而模型B可以通过一定的方式得到一个预训练的权重,模型A需要利用B模型的权重,在此基础上继续训练。

首先我们到官网上去寻找,PyTorch官网上给出了一些保存和加载模型的示例,可以说非常全面总结了模型保存和加载的方法和主义事项,https://pytorch.org/tutorials/beginner/saving_loading_models.html。但是这里的方案都是针对一个完整模型的保存和加载的,不能满足我们这个需求。

因此需要基于此做一些改进,具体如代码所示:

import torch

import torch.nn as nn

class ModelA(nn.Module):

def __init__(self):

super(ModelA, self).__init__()

self.A = nn.Linear(2, 3)

def forward(self, A):

pass

class ModelB(nn.Module):

def __init__(self):

super(ModelB, self).__init__()

self.model_a = ModelA()

self.A = nn.Linear(2, 3)

def forward(self, x):

pass

print("Model")

modelA = ModelA()

modelA_dict = modelA.state_dict()

print('-' * 80)

for key in sorted(modelA_dict.keys()):

parameter = modelA_dict[key]

print(key)

print(parameter.size())

print(parameter)

modelB = ModelB()

modelB_dict = modelB.state_dict()

print('-'*80)

for key in sorted(modelB_dict.keys()):

print('-'*20)

parameter = modelB_dict[key]

print(type(key), key)

print(parameter.size())

print(parameter)

print('-'*20)

print('-'*80)

pretrained_dict = modelA_dict

model_dict = modelB_dict

pretrained_dict = {'model_a.' + k: v for k, v in pretrained_dict.items() if 'model_a.' + k in model_dict}

model_dict.update(pretrained_dict)

modelB.load_state_dict(model_dict)

modelB_dict = modelB.state_dict()

for key in sorted(modelB_dict.keys()):

parameter = modelB_dict[key]

print(key)

print(parameter.size())

print(parameter)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值