Loading Partial Pretrained Model Parameters in PyTorch

Background

Suppose I have a trained Model1 whose parameters have been saved in .pth format, and a model Model2 with exactly the same architecture as Model1. I want Model2 to load the parameters of Model1's feature-extraction modules, while its remaining modules stay randomly initialized.

The use case is K-fold cross-validation: from the second fold onward, I want each model to load part of the parameters of the model trained on the first fold and fine-tune from there, reducing the number of training epochs needed.

Solution

Load the saved parameters of the model trained on the first fold. Because the model was saved in multi-GPU (DataParallel) mode, every key carries a "module." prefix that must be stripped when loading:

import torch
from collections import OrderedDict

state_dict = torch.load("./Fold_1_checkpoint_4.pth", map_location=torch.device('cpu'))

# Strip the "module." prefix that multi-GPU (DataParallel) checkpoints add
state_dict_new = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # drop the leading "module."
    # equivalently: name = k.replace("module.", "", 1)
    state_dict_new[name] = v
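As a quick sanity check, the prefix behavior can be reproduced with a toy model (the Sequential below is hypothetical, used only to show where the "module." prefix comes from):

```python
import torch.nn as nn
from collections import OrderedDict

# Toy model; wrapping it in DataParallel is what introduces "module."
toy = nn.Sequential(nn.Linear(4, 4))
wrapped = nn.DataParallel(toy)

saved = wrapped.state_dict()
print(list(saved.keys()))     # ['module.0.weight', 'module.0.bias']

# Strip the 7-character "module." prefix from every key
stripped = OrderedDict((k[7:], v) for k, v in saved.items())
print(list(stripped.keys()))  # ['0.weight', '0.bias']

# The stripped dict now loads into the unwrapped model
toy.load_state_dict(stripped)
```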

Printing it out to inspect:

OrderedDict([('LD1.conv.weight',
              tensor([[[ 0.0047,  0.0013,  0.0026,  0.0025,  0.0060, -0.0055,  0.0024,
                        -0.0022,  0.0052, -0.0039, -0.0015,  0.0049,  0.0861,  0.0095,
                         0.0032,  0.0033,  0.0037,  0.0008,  0.0023,  0.0080,  0.0045,
                        -0.0023,  0.0097,  0.0034,  0.0160]]])),
             ('LD1.conv.bias', tensor([0.0101])),
             ('position_embedder1.position_embedding',
              tensor([[-0.1612, -0.2091, -0.9842,  ...,  0.7624, -0.7286, -0.6549],
                      [-0.2659,  0.0225, -1.1011,  ...,  1.3899, -0.6597, -0.7806],
                      [-0.3426, -0.2478, -1.0329,  ...,  1.2922, -0.8009, -0.8948],
                      ...,
                      [ 1.4002,  0.0854,  0.4595,  ...,  0.8604, -0.6667, -0.3106],
                      [-0.2322,  0.1524,  0.0809,  ..., -0.3961,  1.2030, -0.3428],
                      [-0.7272,  2.0944, -0.7098,  ...,  0.2923,  0.1994, -0.2035]])),
             ······
             ('encoder1.attn_layers.0.attention.gen1.weight',
              tensor([[-0.2019,  0.2232,  0.2444,  ..., -0.0791,  0.1458,  0.2732],
                      [-0.3740,  0.0378, -0.1596,  ..., -0.2621,  0.2512,  0.1971],
                      [ 0.0941, -0.2160, -0.1651,  ...,  0.1103,  0.2440, -0.8079],
                      ...,
                      [-0.0320, -0.2835,  0.7415,  ...,  0.1435, -0.5525, -0.2343],
                      [-0.2918, -0.4909,  0.0787,  ..., -0.2581, -0.2801,  0.5388],
                      [-0.3744, -0.8789, -0.3549,  ..., -0.0468, -0.1703, -0.1652]])),
             ('encoder1.attn_layers.0.attention.gen1.bias',
              tensor([ 0.1687, -0.8035, -0.0674, -0.4795, -0.2964, -0.5702, -0.9225, -0.0856,
                       1.4210, -0.9099, -1.3605, -0.8256, -1.6564, -1.5394, -0.7184,  0.4084,
                      -0.6281, -2.2652, -0.7902, -0.9841, -0.4321, -0.4909, -1.1394, -1.2694,
                      -0.0219, -0.3581,  0.5682, -1.2015, -0.9408, -1.4014, -1.2602, -1.3793,
                      -0.8697, -2.1542, -0.7590, -1.7175,  0.0472,  0.0683,  0.4771, -1.7388,
                      -1.7942, -0.2653, -0.2880,  0.1545, -1.4922, -0.3756,  0.5862, -1.5961,
                      -0.2772,  0.9524, -0.1537, -0.4896, -0.2866, -0.9232,  0.2729,  0.0779,
                      -1.7396,  0.0744, -0.9369,  0.3411, -0.2821, -0.9431, -1.1710,  0.1081])),
              ······
             ('gate1.0.weight',
              tensor([[ 0.4562, -0.0911,  0.2505,  ...,  0.4371,  0.1663, -0.4728],
                      [ 0.1472,  0.1287,  0.0289,  ...,  0.0881,  0.0862, -0.2207],
                      [ 0.5243,  0.0716,  0.4541,  ..., -0.4321,  0.6488,  0.7909],
                      ...,
                      [-0.0867,  0.2686,  0.1457,  ...,  0.0756,  0.4644, -0.4607],
                      [-0.1575,  0.3019,  0.1621,  ..., -0.1950,  0.2406, -0.2090],
                      [-0.0231, -0.2043, -0.0222,  ...,  0.1126,  0.2387, -0.7590]])),
             ('gate1.0.bias',
              tensor([-0.5651, -0.6304, -0.6943, -0.5891, -0.2952, -0.3897, -0.6592, -0.3604,
                       0.1627, -0.4173, -0.3456, -0.6569, -0.0337, -0.1433, -0.3954, -0.0141])),
             ······
             ('Linear_res2.bias',
              tensor([-0.0314, -0.0065,  0.0268, -0.0321, -0.0738, -0.0063, -0.0531, -0.0615,
                      -0.0552, -0.0357, -0.0639, -0.0893, -0.0361, -0.0736, -0.0347, -0.0330,
                      -0.0759, -0.0828, -0.0665, -0.0439, -0.0652, -0.0718, -0.0231, -0.0297,
                      -0.0448, -0.0408, -0.0181, -0.0379, -0.0274, -0.0526, -0.0139, -0.0404,
                      -0.0284, -0.0496, -0.0515, -0.0054, -0.0704, -0.0666, -0.0385, -0.0613,
                      -0.0471, -0.0886, -0.0398, -0.0616, -0.0304, -0.0558, -0.0301, -0.0728,
                      -0.0869, -0.0409, -0.0514, -0.0737, -0.0510, -0.1048, -0.0555, -0.0530,
                      -0.0721, -0.0315, -0.0070, -0.0687, -0.0707, -0.0403, -0.0611, -0.0340,
                      -0.0935, -0.0339, -0.0462, -0.0842, -0.0516, -0.0445, -0.0364, -0.0748]))])

In this case, the feature-extraction part of the model consists of the modules whose names start with encoder and gate, so all other modules, which should not be transferred, need to be filtered out:

# Take the top-level module name, drop its trailing digit (e.g. "encoder1" -> "encoder"),
# and keep only the encoder* and gate* entries
pretrained_dict = {k: v for k, v in state_dict_new.items() if k.split('.')[0][:-1] in ['encoder', 'gate']}
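To see what this filter keeps, here is a minimal sketch run against dummy keys mirroring the printout above (the values are placeholders):

```python
# Dummy keys mimicking the printed state_dict; values are placeholders
state_dict_new = {
    'LD1.conv.weight': 0,
    'position_embedder1.position_embedding': 0,
    'encoder1.attn_layers.0.attention.gen1.weight': 0,
    'gate1.0.weight': 0,
    'Linear_res2.bias': 0,
}

# Same filter as above: strip the trailing digit from the top-level
# module name and keep only encoder*/gate* entries
pretrained_dict = {k: v for k, v in state_dict_new.items()
                   if k.split('.')[0][:-1] in ['encoder', 'gate']}

print(sorted(pretrained_dict))
# ['encoder1.attn_layers.0.attention.gen1.weight', 'gate1.0.weight']
```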

Initialize Model2:

model2 = Model(args).cuda()
model_dict = model2.state_dict()

Update the entries of model_dict that should be transferred and load it back into the model:

model_dict.update(pretrained_dict)
model2.load_state_dict(model_dict)
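As an aside, load_state_dict also accepts strict=False, which achieves the same partial load in one call by ignoring keys absent from the filtered dict. A minimal sketch with toy module names (encoder1 and head are assumptions, not the author's model):

```python
import torch
import torch.nn as nn

# Toy stand-in: one transferable module, one randomly initialized head
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder1 = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

model2 = Toy()

# Pretend this is the filtered dict from the first fold
pretrained_dict = {k: v for k, v in Toy().state_dict().items()
                   if k.startswith('encoder')}

# strict=False tolerates the keys we deliberately left out
missing, unexpected = model2.load_state_dict(pretrained_dict, strict=False)
print(missing)     # ['head.weight', 'head.bias']
print(unexpected)  # []
```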

At this point, model2 has loaded the feature-extraction parameters from model1 and can be fine-tuned from there.
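Optionally, the transferred modules can be frozen so that fine-tuning only updates the randomly initialized parts. A sketch with toy module names and an assumed learning rate:

```python
import torch
import torch.nn as nn

# Toy stand-in echoing the naming scheme above
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder1 = nn.Linear(4, 4)
        self.gate1 = nn.Linear(4, 4)
        self.Linear_res2 = nn.Linear(4, 2)

model2 = Toy()

# Freeze the transferred feature-extraction modules
for name, param in model2.named_parameters():
    if name.split('.')[0][:-1] in ['encoder', 'gate']:
        param.requires_grad = False

trainable = [n for n, p in model2.named_parameters() if p.requires_grad]
print(trainable)  # ['Linear_res2.weight', 'Linear_res2.bias']

# Pass only the trainable parameters to the optimizer
optimizer = torch.optim.Adam(
    (p for p in model2.parameters() if p.requires_grad), lr=1e-4)
```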

