pytorch加载预训练的模型
解决方案
-
其实很简单,一行代码即可
self.resnet = models.resnet18(models.ResNet18_Weights.DEFAULT)
-
可以使用eval()方法,来将模型参数固定
self.resnet = models.resnet18(models.ResNet18_Weights.DEFAULT).eval()
-
导入之后,模型就相当于一个普通的神经网络的层,可以像linear,relu一样使用它,像这样
def forward(self, x:torch.Tensor): h = self.resnet(x) o = self.linear(h) return o
注意
- models是
torchvision
模块中的,而不是torch
模块中的,因此需要有
import torchvision.models as models
传入的参数ResNet18_Weights
也是models中的,记得加上前缀 model.Resnet()
中的参数也可以填pretrained=True
,但是运行时会有一个warning,显示这个参数deprecated,建议换成新的参数。官网上也有类似的说法。此外,官网还提供了不同版本的参数选择。