迁移学习(Transfer Learning)简单来说就是使用一些已经训练好的模型迁移到类似的新的问题进行使用,而不必对新问题重新建模,从头训练和优化参数。这些训练好的模型同时包含了优化好的参数,在使用的时候只需要做微调就可以应用到新问题。
比如我现在需要做一个图像二分类的任务,我可以使用已经训练好的resnet_152模型来进行微调。
首先应该将我们的输入数据处理成resnet同样的输入尺寸224*224
然后通过torchvision下载我们需要的模型:
from torchvision import models
model = models.resnet152(pretrained=True) # 参数pretrained=True表示需要下载预训练好的参数
冻结模型的网络结构,使参数不更新
for param in model.parameters():
param.requires_grad = False
根据我们需要的输出更改网络结构
model.fc = nn.Linear(model.fc.in_features, 2) #这里模型最后的输出类别为2
后面再正常对我们的模型进行训练