深度学习笔记--权重文件、模型参数和预训练模型的使用

目录

1--打印权重文件参数

2--打印模型参数

3--使用权重文件参数更新模型的参数

4--将模型部分参数固定(不进行梯度下降)

5--参考


1--打印权重文件参数

import torch

weights_files = './test.pt' # 权重文件路径
weights = torch.load(weights_files) # 加载权重文件

for k, v in weights.items(): # key, value
    print(k, v)  # 打印参数名、参数值

2--打印模型参数

'''
Class Model(nn.module):
    #...
'''
# or
# from .xx.yy import Model

model = Model() # 初始化模型
model_dict = model.state_dict() # 模型参数字典
for k, v in model_dict.items(): # key, value
    print(k, v) # 打印参数名、参数值

3--使用权重文件参数更新模型的参数

model = Model() # 初始化模型
model_dict = model.state_dict() # 模型参数

weights_files = './test.pt' # 权重文件 
weights = torch.load(weights_files) # 权重文件参数

# 模型参数和权重参数匹配(可能新模型会作改动)
match_dict = {k: v for k, v in weights.items() if k in model_dict}

# 根据参数匹配,将权重文件的参数加载到模型参数
model_dict.update(match_dict)  # 相当于把预训练网络层的参数更新进来

# 更新模型参数
model.load_state_dict(model_dict)

4--将模型部分参数固定(不进行梯度下降)

model = Model()
for name, param in model.named_parameters():
    # print(name)
    # print(param)
    if name == 'xxx': # 选择参数进行固定
        param.requires_grad = False

或者使用以下方式固定参数:

model = Model()

for i, param in enumerate(model.parameters()):
    if i < 5: # 根据已知的参数顺序,选择参数进行固定
        # print(param)
        param.requires_grad = False

5--参考

参考链接1

  • 3
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值