pytorch权重的匹配与加载

 model_dict      = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location = device)
        load_key, no_load_key, temp_dict = [], [], {}
        for k, v in pretrained_dict.items():
            if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
                temp_dict[k] = v
                load_key.append(k)
            else:
                no_load_key.append(k)
        model_dict.update(temp_dict)
        model.load_state_dict(model_dict)

这段代码的作用是用于加载一个预训练的 PyTorch 模型的权重,并将这些权重加载到一个已经定义的模型中。代码执行的主要步骤如下:

1. `model_dict` 和 `pretrained_dict` 初始化:首先,代码创建了两个字典 `model_dict` 和 `pretrained_dict` 分别用于存储当前模型的权重和从外部文件加载的预训练模型的权重。

2. 从外部文件加载预训练权重:代码通过调用 `torch.load(model_path, map_location=device)` 从指定路径 `model_path` 加载了一个预训练模型的权重,其中 `map_location=device` 指定了将权重加载到哪个计算设备 (GPU 或 CPU) 上。

3. 循环遍历预训练权重字典:代码通过一个循环遍历 `pretrained_dict` 中的键-值对,其中键 `k` 是预训练权重的名称,值 `v` 是相应的权重张量。

4. 比较权重形状:对于每个键 `k`,代码检查它是否存在于当前模型的权重字典 `model_dict` 中,并且检查权重的形状是否与当前模型中的权重形状相匹配。如果键 `k` 存在于 `model_dict` 中且形状匹配,那么将这个权重添加到临时字典 `temp_dict` 中,并将键 `k` 添加到 `load_key` 列表中,表示这个权重需要加载。

5. 处理不匹配的权重:如果预训练权重的键不在当前模型的权重字典中,或者形状不匹配,那么将这个键添加到 `no_load_key` 列表中,表示这个权重不会被加载。

6. 更新当前模型的权重字典:将临时字典 `temp_dict` 中的权重添加到当前模型的权重字典 `model_dict` 中,以确保只有匹配的权重被加载,不匹配的权重不会破坏当前模型的结构。

7. 加载更新后的权重:最后,使用 `model.load_state_dict(model_dict)` 将更新后的权重加载到当前模型中,从而将预训练的权重应用到当前模型中。

总之,这段代码的主要作用是加载预训练模型的权重,并将这些权重应用到已经定义的模型中,确保权重形状匹配的情况下进行加载。如果形状不匹配,不匹配的权重将被忽略。这通常用于迁移学习,其中你希望使用一个在某个任务上预训练过的模型来初始化你的模型,并在之后微调以适应新的任务。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值