一.引言
对于迁移学习来说,我们经常需要训练一个完整的模型后将其中的部分模块取出用做其他模型进行微调,所以对于参数和模型结构的保存和迁移也至关重要。
二.参数迁移
-
使用
torch.save()
函数将模型A中指定模块的参数保存到文件中。您需要创建一个字典,其中键是模块的名称,值是对应模块的参数。 -
示例代码如下:
import torch import torch.nn as nn # 假设模型A包含两个模块:'module1' 和 'module2' class ModelA(nn.Module): def __init__(self): super(ModelA, self).__init__() self.module1 = nn.Linear(10, 20) self.module2 = nn.Linear(20, 30) # 创建模型A的实例并训练它... # 保存模型A中指定模块的参数 state_dict_to_save = { 'module1': model_a.module1.state_dict(), 'module2': model_a.module2.state_dict() } torch.save(state_dict_to_save, "specified_modules.pth")
可以通过以下步骤将保存的指定模块的参数加载到一个新的模型中:
-
首先,创建一个与模型A相同结构的模型B,包括相同的模块名字和结构。
-
使用
torch.load()
函数读取保存的参数文件 “specified_modules.pth”。这将返回一个字典,其中包含了模型A中指定模块的参数。 -
遍历模型B的模块,将模型A中对应模块的参数值赋给模型B的对应模块。
import torch import torch.nn as nn # 假设模型B的结构与模型A的指定模块相同 class ModelB(nn.Module): def __init__(self): super(ModelB, self).__init__() self.module1 = nn.Linear(10, 20) # 假设这是模型B的一个相同模块 self.module2 = nn.Linear(20, 30) # 假设这是模型B的另一个相同模块 # 创建模型B的实例 model_b = ModelB() # 加载模型A中指定模块的参数到模型B state_dict_loaded = torch.load("specified_modules.pth") model_b.module1.load_state_dict(state_dict_loaded['module1']) model_b.module2.load_state_dict(state_dict_loaded['module2']) # 现在,model_b的参数已经与模型A中指定模块相同
如果只有一个模块需要迁移,那么无需字典也可以实现。