Pytorch 预训练模型下载和加载

PyTorch 加载和下载预训练模型可参考:pytorch预训练模型的下载地址以及解决下载速度慢的方法

- 下载地址

常用预训练模型在这里面:https://github.com/pytorch/vision/tree/master/torchvision/models

但是上述网址只有常见的 backbone (vgg, resnet, densenet, alexnet),在 GitHub 上,还找到了一个项目,提供 NASNet, ResNeXt, ResNet, InceptionV4, InceptionResnetV2, Xception, DPN 等预训练模型的下载:https://github.com/Cadene/pretrained-models.pytorch

具体下载位置是:https://data.lip6.fr/cadene/pretrainedmodels/

- 加载预训练模型

一般使用的是使用 model.load_state_dict() 函数。

model_urls = {  'resnet50': '/home/huihua/NewDisk1/pretrain_parameter/resnet50-19c8e357.pth',}
def resnet50(pretrained=False, **kwargs):
	model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
	if pretrained:
		model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
	return model

此时它会到指定的网站下载预训练模型到本地缓存中,本地缓存的位置(Linux系统)一般在:

.cache/torch/checkpoints

PyTorch 在加载模型时候首先检查本地缓存是否已经存在预训练模型,所以在本地缓存汇总预先放入已经下载的模型可快速加载模型。

如果需要更改预训练模型的位置,可以在文件开头加入:

os.environ['TORCH_HOME']= './pretrained_models/'

pretrained_models 文件夹下新建一个 checkpoints 文件夹并把预训练模型放入即可。

- 参考

  1. pytorch预训练模型下载URL及加载调用方法
  2. pytorch学习笔记之加载预训练模型
  • 2
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值