Note: These notes are based on the code in the PyTorch tutorial "Transfer Learning for Computer Vision Tutorial" and apply transfer learning to images.
Code: Transfer Learning. For more details, see "Transfer Learning - Machine Learning's Next Frontier".
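Transfer learning: after loading a pretrained model, we usually freeze part of its parameters and train only the last few layers, which preserves the feature-extraction ability of the earlier layers. The pretrained model must have something in common with the new dataset (for example, both are image classification problems) so that its ability to extract features transfers effectively to the new model.
Two main ways that transfer learning is used: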
ConvNet as a fixed feature extractor: Here, you ‘freeze’ the weights of all the parameters in the network except that of the final several layers (aka “the head”, usually fully connected layers). These last layers are replaced with new ones initialized with random weights and only these layers are trained.
Finetuning the ConvNet: Instead of random initialization, the model is initialized using a pretrained network, after which the training proceeds as usual but with a different dataset. Usually the head (or part of it) is also replaced in the network in case there is a different number of outputs. It is common in this method to set the learning rate to a smaller number. This is done because the network is already trained, and only minor changes are required to “finetune” it to a new dataset.
Sometimes we could combine the above two methods: First we can freeze the feature extractor, and train the head. After that, we could unfreeze the feature extractor (or part of it), set the learning rate to something smaller, and continue training.
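The combined schedule can be sketched roughly as follows. This is a minimal illustration, not code from the tutorial: the tiny `nn.Sequential` network stands in for a real pretrained ConvNet, and the helper names `set_requires_grad` and `train_one_step`, the random batch, and the two learning rates are all hypothetical choices.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Toy stand-in for a pretrained network: one "feature extractor" layer plus a new head.
model = nn.Sequential(
    nn.Linear(8, 16),  # plays the role of the pretrained feature extractor
    nn.ReLU(),
    nn.Linear(16, 2),  # newly initialized head
)
backbone, head = model[0], model[2]


def set_requires_grad(module, flag):
    """Freeze (flag=False) or unfreeze (flag=True) every parameter of a module."""
    for param in module.parameters():
        param.requires_grad = flag


def count_trainable(module):
    """Number of scalar parameters that will receive gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def train_one_step(model, optimizer, inputs, targets):
    """Run a single optimization step and return the loss value."""
    criterion = nn.CrossEntropyLoss()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


inputs = torch.randn(4, 8)  # random batch, for illustration only
targets = torch.tensor([0, 1, 0, 1])

# Phase 1: freeze the feature extractor and train only the head.
set_requires_grad(backbone, False)
trainable_phase1 = count_trainable(model)
optimizer_head = optim.SGD(head.parameters(), lr=1e-2, momentum=0.9)
train_one_step(model, optimizer_head, inputs, targets)

# Phase 2: unfreeze everything and continue with a smaller learning rate.
set_requires_grad(backbone, True)
trainable_phase2 = count_trainable(model)
optimizer_all = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
train_one_step(model, optimizer_all, inputs, targets)
```

A new optimizer is created for the second phase so that the newly unfrozen parameters are registered with it and the lower learning rate applies to the whole network.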
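1. Data Preprocessing
Because different networks place different requirements on their input images, we should check how many images each class has and the aspect ratio of the images. For classes with too few samples, we can use data augmentation.
transforms: implements data augmentation (applied to the training dataset) and preprocesses the data, including resizing, converting to a tensor, and, if needed, normalization.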
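from torchvision import transforms

transforms_mean = [0.485, 0.456, 0.406]  # ImageNet per-channel means
transforms_std = [0.229, 0.224, 0.225]   # ImageNet per-channel standard deviations
image_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # augmentation, training only
        transforms.ToTensor(),  # Normalize operates on tensors
        transforms.Normalize(mean=transforms_mean, std=transforms_std),
    ]),
    'val': transforms.Compose([
        transforms.Resize(size=(256, 256)),  # assumed here: match the 256-pixel training crop
        transforms.ToTensor(),
        transforms.Normalize(mean=transforms_mean, std=transforms_std),
    ]),
}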
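- Create the dataset
torchvision.datasets.ImageFolder: a generic data loader. Its root argument can also be built by joining paths, as in the following code. The `filter` call keeps only the image files, which also skips the hidden files that macOS adds to a directory.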
import os

base_dir = os.path.dirname(os.path.abspath(__file__))  # absolute directory of this .py file
img_dir = os.path.join(base_dir, img_path)  # img_path: image folder relative to this file
# keep only the .png files found in img_dir
image_names = list(filter(lambda x: x.endswith('.png'), os.listdir(img_dir)))
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
torch.utils.data.Dataset: see the official PyTorch documentation on creating a dataset. This is a map-style dataset: we subclass Dataset and override __getitem__() and __len__() to implement our own. For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on disk.
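A minimal sketch of such a map-style dataset is shown here. To stay self-contained it indexes in-memory tensors instead of reading files from disk; the class name `ToyImageDataset` and the random data are hypothetical.

```python
import torch
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Map-style dataset over in-memory tensors (a stand-in for images on disk)."""

    def __init__(self, images, labels, transform=None):
        assert len(images) == len(labels)
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples; used by len(dataset) and by DataLoader.
        return len(self.images)

    def __getitem__(self, idx):
        # Return the idx-th sample and its label, applying the transform if given.
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]


images = torch.rand(6, 3, 8, 8)  # six fake 3x8x8 images
labels = torch.tensor([0, 1, 0, 1, 0, 1])
dataset = ToyImageDataset(images, labels)
sample_image, sample_label = dataset[2]
```

A version that reads from disk would typically store file paths in `__init__` and open the image inside `__getitem__`, so that only the requested sample is loaded.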
- Create data loader
torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4)
Only these four parameters are discussed here:
dataset: the dataset from which to load the data.
batch_size: how many samples per batch to load.
shuffle: usually set to True for the training set; the validation set does not need shuffling.
num_workers: how many subprocesses to use for data loading.
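The following sketch shows how these parameters behave on a small made-up dataset of 25 samples. `TensorDataset` and the random tensors are illustrative assumptions, and `num_workers=0` (load in the main process) is used here only to keep the example lightweight.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

images = torch.rand(25, 3, 8, 8)  # 25 fake images
labels = torch.randint(0, 2, (25,))
dataset = TensorDataset(images, labels)

# Shuffle the training data; keep the validation data in a fixed order.
train_loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0)

# 25 samples with batch_size=10 give two full batches and one batch of 5.
batch_sizes = [batch_images.shape[0] for batch_images, _ in train_loader]
first_val_images, first_val_labels = next(iter(val_loader))
```

The last batch is smaller because `drop_last` defaults to False.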
2. Training The Model
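- Load the parameters of the original model
model_ft = copy.deepcopy(pre_model.state_dict())
state_dict() returns the parameters of the pretrained model; copy.deepcopy (which needs import copy) takes an independent snapshot that later training will not modify.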
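from torch import nn
from torchvision import models

model = models.vgg16(pretrained=True)  # torchvision >= 0.13 prefers the weights= argument
for param in model.parameters():
    param.requires_grad = False  # freeze all pretrained parameters
# Parameters of newly constructed modules have requires_grad=True by default
# VGG16 has no .fc attribute; its last fully connected layer is model.classifier[6]
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 2)  # 2 is the number of classes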
- Load part of a pretrained model
from torchvision import models

# load resnet50; Net is our own model class
resnet = models.resnet50(pretrained=True)
model = Net(...)
# load parameters
pretrained_dict = resnet.state_dict()
model_dict = model.state_dict()
# drop the keys of pretrained_dict that do not exist in model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# update model_dict with the pretrained values
model_dict.update(pretrained_dict)
# load the merged state_dict into the model
model.load_state_dict(model_dict)
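Here is a small runnable sketch of the same partial-loading idea. Two toy modules replace resnet50 and the custom `Net` so that nothing has to be downloaded; the class names `Backbone` and `Net` and their layers are hypothetical. The sketch also compares tensor shapes, an extra check added here as an assumption, because a key that exists in both models but has a different shape would make `load_state_dict` fail.

```python
import torch
from torch import nn


class Backbone(nn.Module):
    """Stand-in for the pretrained network."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.fc = nn.Linear(4, 10)


class Net(nn.Module):
    """Stand-in for our own model, which shares only some layers with Backbone."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)  # same name and shape: reused
        self.fc = nn.Linear(4, 2)                   # same name, different shape: skipped
        self.extra = nn.Linear(2, 2)                # only in the new model: untouched


torch.manual_seed(0)
pretrained = Backbone()
model = Net()

pretrained_dict = pretrained.state_dict()
model_dict = model.state_dict()

# Keep only entries whose name and shape both match the new model.
filtered_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
```

After loading, `model.conv` carries the weights of the stand-in pretrained network, while `model.fc` and `model.extra` keep their fresh initialization.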
3. Loss Function & Optimizer
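When fine-tuning, the parameter list should be long and contain all of the model's parameters. When doing feature extraction, however, the list should be short and include only the weights and biases of the reshaped layers.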
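from torch import optim

# Send the model to the GPU (model_ft, device and feature_extract are defined earlier)
model_ft = model_ft.to(device)
# Gather the parameters to be optimized/updated in this run.
# If we are fine-tuning, we will be updating all parameters.
# If we are doing feature extraction, we will only update the parameters
# that we have just initialized, i.e. those with requires_grad set to True.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)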
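The notes do not show the loss function itself. The sketch below assumes `nn.CrossEntropyLoss`, a common choice for classification, and pairs it with an SGD optimizer that receives only the trainable parameters. The toy two-layer network and the random batch are made up for illustration.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Toy model: a frozen "feature extractor" layer followed by a trainable head.
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Linear(16, 2),
)
for param in model[0].parameters():
    param.requires_grad = False  # feature extraction: freeze the first layer

# Only parameters with requires_grad=True are handed to the optimizer.
params_to_update = [p for p in model.parameters() if p.requires_grad]
criterion = nn.CrossEntropyLoss()  # assumed loss; expects raw logits and class indices
optimizer = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

inputs = torch.randn(4, 8)
targets = torch.tensor([0, 1, 0, 1])

frozen_before = model[0].weight.clone()
head_before = model[2].weight.clone()

# One optimization step.
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()
```

After the step, only the head has moved. The frozen layer never receives a gradient, so its weights stay exactly as they were.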
4. Details of Transfer Learning
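In general, transfer learning consists of the following steps:
- Initialize the pretrained model
- Reshape the final layer so that it has as many outputs as there are classes in the new dataset
- Define for the optimization algorithm which parameters we want to update during training
- Run the training step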