目录
1--打印权重文件参数
import torch
weights_files = './test.pt' # 权重文件路径
weights = torch.load(weights_files) # 加载权重文件
for k, v in weights.items(): # key, value
print(k, v) # 打印参数名、参数值
2--打印模型参数
'''
Class Model(nn.module):
#...
'''
# or
# from .xx.yy import Model
model = Model() # 初始化模型
model_dict = model.state_dict() # 模型参数字典
for k, v in model_dict.items(): # key, value
print(k, v) # 打印参数名、参数值
3--使用权重文件参数更新模型的参数
model = Model() # 初始化模型
model_dict = model.state_dict() # 模型参数
weights_files = './test.pt' # 权重文件
weights = torch.load(weights_files) # 权重文件参数
# 模型参数和权重参数匹配(可能新模型会作改动)
match_dict = {k: v for k, v in weights.items() if k in model_dict}
# 根据参数匹配,将权重文件的参数加载到模型参数
model_dict.update(match_dict) # 相当于把预训练网络层的参数更新进来
# 更新模型参数
model.load_state_dict(model_dict)
4--将模型部分参数固定(不进行梯度下降)
model = Model()
for name, param in model.named_parameters():
# print(name)
# print(param)
if name == 'xxx': # 选择参数进行固定
param.requires_grad = False
或者使用以下方式固定参数:
model = Model()
for i, param in enumerate(model.parameters()):
if i < 5: # 根据已知的参数顺序,选择参数进行固定
# print(param)
param.requires_grad = False