从本地地址直接加载预训练模型
在本地,有几种方式可以避免下载直接加载预训练模型:
https://www.cnblogs.com/ywheunji/p/10605614.html
- 直接修改源码,改为本地地址
- 把模型权重下载至torch的缓存文件夹
当需要迁移学习时加载预训练模型
此时,我的模型和预训练的模型有些结构可能不一致,需要多出一步,筛选掉预训练模型中与我的结构不一致的部分:
https://blog.csdn.net/VictoriaW/article/details/72821329
https://www.94e.cn/info/4270
vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
torch.load
从文件中加载一个用torch.save()保存的对象。
当我是用自己从网上下载的模型的时候要用torch.load:
pretrain = torch