以下代码等价,但第一种慢,第二种快:
import torchvision.models as models
self.resnet = models.resnet18(pretrained=True) # 联网下载,慢
self.model = models.resnet18(pretrained=False)
state_dict = torch.load('src/model/resnet18-5c106cde.pth') # 自己从网上下载.pth,快
self.model.load_state_dict(state_dict) # 再把读出来的参数放进没有参数的模型
当pretrained=True,才会联网下载模型,否则很快,仅得到一个没训练过的模型。
.pth文件或者state_dict变量:模型参数,里面是模型每一层具体的浮点数
model:模型,不含参数
model和.pth如果是对应的,就可以用model.load_state_dict加载。注意这条语句是在模型上直接修改,不应写成model = model.load_state_dict。
所以我们可以自己在浏览器下载模型,然后加载进去。那么去哪里下载呢?Ctrl+函数打开源码自己就可以找到。