关于pytorch直接加载resnet50模型及模型参数

1.由于与resnet50的分类数不一样,所以在调用时,要使用num_classes=分类数

model = torchvision.models.resnet50(pretrained=True,num_classes=5000)   #pretrained=True 既要加载网络模型结构,又要加载模型参数

如果需要加载模型本身的参数,需要使用pretrained=True

2.由于最后一层的分类数不一样,所以最后一层的参数数目也就不一样,所以在加载模型参数时要去掉最后一层

def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        
        for k in list(state_dict.keys()):  #固定遍历对象
            print(k)
            if k == "fc.weight" or k == "fc.bias":
                state_dict.pop(k)  #删除最后一层的模型参数
         
        
        model.load_state_dict(state_dict,strict=False)  #非严格加载模型参数
    return model

由于字典中的元素是不固定的,所以在遍历的时候需要使用list,将其变为列表,这样元素位置就固定了,才可以进行后面的pop操作。

 由于没有加载最后一层,所以参数中需要加上strict=False

3.总结一下如何调用pytorch框架中已有的模型及其参数(如果是分类器,且最后一层分类数不一样)

a.实例化model

 b.点击resnet50,到源文件中去修改去除最后一层参数

  • 6
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch实现ResNet50模型的代码如下所示: ```python import torch import torchvision.models as models # 加载预训练的ResNet50模型 model = models.resnet50(pretrained=True) # 替换最后一层全连接层的输出类别数 num_classes = 1000 # 假设分类数为1000 model.fc = torch.nn.Linear(model.fc.in_features, num_classes) # 将模型设置为评估模式 model.eval() ``` 在这段代码中,我们首先导入了`torch`和`torchvision.models`模块。然后,我们使用`models.resnet50(pretrained=True)`加载了预训练的ResNet50模型。接下来,我们替换了模型的最后一层全连接层,将其输出类别数设置为我们需要的分类数。最后,我们将模型设置为评估模式。 请注意,这段代码中没有使用到引用\[1\]、\[2\]和\[3\]中的具体内容,因为这些内容与问题的回答无关。 #### 引用[.reference_title] - *1* [关于pytorch直接加载resnet50模型模型参数](https://blog.csdn.net/eye123456789/article/details/124948949)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [pytorch实现resnet50(训练+测试+模型转换)](https://blog.csdn.net/gm_Ergou/article/details/118419795)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值