在预训练网络的基础上,修改部分层得到自己的网络,通常我们需要解决的问题包括:
1. 从预训练的模型加载参数
在PyTorch中,每个Variable数据含有两个flag(requires_grad和volatile)用于指示是否计算此Variable的梯度。设置requires_grad = False,或者设置volatile=True,即可指示不计算此Variable的梯度
1. 从预训练的模型加载参数
2. 对新网络两部分设置不同的学习率,主要训练自己添加的层
PyTorch提供的预训练模型
PyTorch定义了几个常用模型,并且提供了预训练版本:
- AlexNet: AlexNet variant from the “One weird trick” paper.
- VGG: VGG-11, VGG-13, VGG-16, VGG-19 (with and without batch normalization)
- ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
- SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
预训练模型可以通过设置pretrained=True来构建:
eg:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。
加载预训练模型
加载参数可以参考 apaszke推荐的做法 ,即删除与当前model不匹配的key。torch.nn.Module对象有函数static_dict()用于返回包含模块所有状态的字典,包括参数和缓存。键是参数名称或者缓存名称。
函数Module::load_state_dict(state_dict)用state_dict中的状态值更新模块的状态值。static_dict中的键应该和函数static_dict()返回的字典中的键完全一样。
下面给出加载预训练的模型的示例:
vgg16 = models.vgg16(pretrained=True)
pretrained_dict = vgg16.state_dict()
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
不同层设置不同学习率的方法
此部分主要参考 PyTorch教程的Autograd machnics部分在PyTorch中,每个Variable数据含有两个flag(requires_grad和volatile)用于指示是否计算此Variable的梯度。设置requires_grad = False,或者设置volatile=True,即可指示不计算此Variable的梯度
for param in model.parameters():
param.requires_grad = False
注意,在模型测试时,对input_data设置volatile=True,可以节省测试时的显存