迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险。
如果用原生网络进行迁移学习非常简单,其核心是
model.load_state_dict()
以Pytorch中官方提供的Resnet加载预训练权重的代码为例:
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
但很多时候,我们可能需要对原生网络做一些修改,比如自定义地增加一些网络层,改变某些网络层的结构等等,这时候如果直接像上面那样直接加载就会报错。
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
因为网络结构发生了变化,预训练权重是以字典的形式存储的,它会和当前网络结构的字典对应不上。
因此,我们需要通过加载部分预训练权重的方式来进行初始化。