Transfer Learning


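Note: These notes are based on the code of the PyTorch tutorial Transfer Learning for Computer Vision Tutorial and cover transfer learning on images.
Code: Transfer Learning. For more details, see Transfer Learning - Machine Learning's Next Frontier.
In transfer learning we usually load a pre-trained model, freeze part of its parameters, and train only the last few layers. This preserves the earlier layers' ability to extract object features. The pre-trained model must have something in common with the new dataset (for example, both are image classification problems); only then can its ability to extract features be transferred effectively to the new model.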

There are two main ways transfer learning is used:

  1. ConvNet as a fixed feature extractor: Here, you ‘freeze’ the weights of all the parameters in the network except those of the final few layers (aka “the head”, usually fully connected layers). These last layers are replaced with new ones initialized with random weights, and only these layers are trained.

  2. Finetuning the ConvNet: Instead of random initialization, the model is initialized with a pretrained network, after which training proceeds as usual but on a different dataset. Usually the head (or part of it) is also replaced if the number of outputs differs. It is common in this method to use a smaller learning rate, because the network is already trained and only minor changes are required to “finetune” it to the new dataset.

Sometimes we could combine the above two methods: First we can freeze the feature extractor, and train the head. After that, we could unfreeze the feature extractor (or part of it), set the learning rate to something smaller, and continue training.
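A minimal sketch of this two-stage schedule, using a hypothetical `backbone`/`head` split built from small layers so nothing has to be downloaded. In stage 1 the backbone is frozen and only the head is optimized; in stage 2 the backbone is unfrozen and given its own optimizer parameter group with a 10x smaller learning rate. The learning rates are illustrative assumptions, not values from the tutorial.

```python
import torch
from torch import nn, optim

# hypothetical model: a "pre-trained" backbone and a freshly initialized head
backbone = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
head = nn.Linear(32, 2)
model = nn.Sequential(backbone, head)

# Stage 1: freeze the backbone and train only the head
for p in backbone.parameters():
    p.requires_grad = False
stage1_opt = optim.SGD(head.parameters(), lr=1e-3, momentum=0.9)

# one illustrative training step on random data
x, y = torch.randn(4, 16), torch.randint(0, 2, (4,))
stage1_opt.zero_grad()
nn.functional.cross_entropy(model(x), y).backward()  # frozen weights get no gradient
stage1_opt.step()

# Stage 2: unfreeze the backbone and fine-tune everything,
# with a smaller learning rate for the pre-trained layers
for p in backbone.parameters():
    p.requires_grad = True
stage2_opt = optim.SGD(
    [
        {"params": backbone.parameters(), "lr": 1e-4},
        {"params": head.parameters()},  # uses the default lr given below
    ],
    lr=1e-3,
    momentum=0.9,
)
```

Parameter groups let a single optimizer apply different learning rates to different parts of the network, which is one common way to realize the "smaller learning rate" for the pre-trained layers.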

1. Data Preprocessing

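Because different networks have different requirements on input images, we should check how many images each class has and what their dimensions are, i.e. the aspect ratio. For classes with too few images, data augmentation can be used.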
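A minimal sketch of these checks using only the standard library. It assumes an ImageFolder-style layout (one sub-folder per class under a root directory) and takes image sizes as `(width, height)` pairs; reading the sizes from the files (for example with Pillow's `Image.open(path).size`) is left out. `count_per_class` and `aspect_ratio_summary` are hypothetical helper names.

```python
import os
from statistics import mean


def count_per_class(root):
    """Count the files in each class sub-folder of `root`, skipping hidden files."""
    counts = {}
    for entry in sorted(os.scandir(root), key=lambda e: e.name):
        if entry.is_dir():
            counts[entry.name] = sum(
                1 for f in os.listdir(entry.path) if not f.startswith(".")
            )
    return counts


def aspect_ratio_summary(sizes):
    """Return (min, max, mean) of width / height over (width, height) pairs."""
    ratios = [w / h for w, h in sizes]
    return min(ratios), max(ratios), mean(ratios)
```

A large spread in the counts points to classes that may need augmentation; a wide range of aspect ratios suggests checking how resizing and cropping will distort or cut the images.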

  1. torchvision.transforms: implements data augmentation (applied to the training dataset) and basic preprocessing of the dataset, including resizing, conversion to tensors, and, if needed, normalization
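from torchvision import transforms

# per-channel mean/std of ImageNet, the statistics torchvision's pre-trained models expect
transforms_mean = [0.485, 0.456, 0.406]
transforms_std = [0.229, 0.224, 0.225]

image_transforms = {
	# training set: random augmentation, then conversion and normalization
	'train': transforms.Compose([
		transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),  # random crop of 80-100% of the area, resized to 256x256
		transforms.RandomRotation(degrees=15),   # rotate by a random angle in [-15, 15] degrees
		transforms.ColorJitter(),                # no-op with default arguments; pass e.g. brightness=0.2 to jitter
		transforms.RandomHorizontalFlip(),       # flip with probability 0.5
		transforms.CenterCrop(size=224),         # 224x224 input size commonly used by ImageNet models
		transforms.ToTensor(),                   # PIL image -> float tensor in [0, 1], shape C x H x W
		transforms.Normalize(mean=transforms_mean, std=transforms_std),
	]),
	# validation set: deterministic resize + center crop, no augmentation
	'val': transforms.Compose([
		transforms.Resize(size=256),
		transforms.CenterCrop(size=224),
		transforms.ToTensor(),
		transforms.Normalize(mean=transforms_mean, std=transforms_std),
	]),
}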
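`transforms.Normalize` standardizes each channel `c` as `(x - mean[c]) / std[c]`. A minimal sketch that reproduces this by hand with plain torch tensors, assuming a 3 x H x W float image in [0, 1] such as `ToTensor()` produces; `normalize_chw` is a hypothetical helper name, and the mean/std values are the ImageNet statistics used above.

```python
import torch

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_chw(img, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Per-channel (x - mean) / std for a C x H x W tensor."""
    mean_t = torch.tensor(mean, dtype=img.dtype).view(-1, 1, 1)  # shape C x 1 x 1
    std_t = torch.tensor(std, dtype=img.dtype).view(-1, 1, 1)
    return (img - mean_t) / std_t  # broadcasts over H and W
```

An image whose every pixel equals the channel mean maps to all zeros, so the pre-trained network receives inputs on the same scale it saw during ImageNet training.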

  2. Create the dataset
  • torchvision.datasets.ImageFolder: a generic data loader. The files under root must be organized as follows:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

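You can also collect the file names yourself. The `filter` call keeps only `.png` files, which also drops the hidden files macOS adds to directories (such as `.DS_Store`).
import os
base_dir = os.path.dirname(os.path.abspath(__file__))  # absolute directory of this .py file
img_dir = os.path.join(base_dir, img_path)  # img_path: image folder relative to this file
# list the folder and keep only .png files
image_names = list(filter(lambda x: x.endswith('.png'), os.listdir(img_dir)))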
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
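ImageFolder derives the labels from the sub-folder names: the class folders are sorted alphabetically and numbered from 0, so in the layout above `cat` gets index 0 and `dog` index 1. A standard-library sketch of that mapping, which also collects `(path, label)` pairs while skipping hidden files; `find_samples` is a hypothetical helper, not torchvision's own code, and it only accepts `.png` files to match the example.

```python
import os


def find_samples(root, extension=".png"):
    """Return (class_to_idx, samples) for a root/<class>/<file> layout."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # skip hidden files (e.g. .DS_Store) and anything that is not an image
            if fname.startswith(".") or not fname.lower().endswith(extension):
                continue
            samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return class_to_idx, samples
```

The same sorted-folder convention means the label indices stay stable between runs, as long as the set of class folders does not change.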
  • torch.utils.data.Dataset: the base class for map-style datasets in PyTorch (see the official documentation). To build our own dataset we subclass Dataset and implement __getitem__() and __len__(). For example, when accessed as dataset[idx], such a dataset could read the idx-th image and its corresponding label from a folder on disk.
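A minimal map-style `Dataset` following this pattern, assuming a list of `(path, label)` pairs is already available. To keep it runnable without image files, the function that opens a path is passed in; `PathLabelDataset` and its `loader` argument are hypothetical names. In practice `loader` would open the image (for example with Pillow) and `transform` would be one of the `image_transforms` pipelines.

```python
from torch.utils.data import Dataset


class PathLabelDataset(Dataset):
    """Map-style dataset over (path, label) pairs.

    `loader` turns a path into a sample (e.g. a PIL image); `transform`
    is applied to that sample before it is returned.
    """

    def __init__(self, samples, loader, transform=None):
        self.samples = list(samples)
        self.loader = loader
        self.transform = transform

    def __len__(self):
        # number of samples; DataLoader uses this to build its index sampler
        return len(self.samples)

    def __getitem__(self, idx):
        # dataset[idx] -> (transformed sample, label)
        path, label = self.samples[idx]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label
```

Opening the file inside `__getitem__` rather than in `__init__` keeps memory use low, because only the images in the current batch are loaded.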
  3. Create the data loader
    torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4): here we only discuss these four parameters.
  • dataset: dataset from which to load the data
  • batch_size: how many samples per batch to load
  • shuffle: whether to reshuffle the data at every epoch; usually set to True for the training set, while there is no need to shuffle the validation set.
  • num_workers: how many subprocesses to use for data loading.
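A small sketch of these four parameters on an in-memory `TensorDataset` filled with random data, an assumption made so the example runs without image files. `num_workers=0` loads in the main process, which keeps the example portable; on real image data a value such as the 4 above parallelizes decoding.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 fake "images" (3 x 8 x 8) with binary labels, purely for illustration
images = torch.randn(10, 3, 8, 8)
labels = torch.randint(0, 2, (10,))
dataset = TensorDataset(images, labels)

train_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

batch_sizes = [x.shape[0] for x, y in train_loader]
print(batch_sizes)  # → [4, 4, 2]
```

With 10 samples and `batch_size=4`, the last batch holds the remaining 2 samples; passing `drop_last=True` would discard it. Because `val_loader` does not shuffle, its batches come out in dataset order.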

2. Training The Model

  1. Load the parameters of the pre-trained model
    model_ft = copy.deepcopy(pre_model.state_dict()): use state_dict() to take a copy of the pre-trained model's parameters.
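from torch import nn
from torchvision import models

# on torchvision >= 0.13, prefer weights=models.VGG16_Weights.DEFAULT
model = models.vgg16(pretrained=True)
for param in model.parameters():
	param.requires_grad = False  # freeze all pre-trained parameters
# Parameters of newly constructed modules have requires_grad=True by default.
# VGG16 has no `fc` attribute: its last layer is model.classifier[6] (4096 -> 1000).
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 2)  # here 2 is the number of classes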
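To check that only the new head will be trained, count trainable and frozen parameters after replacing it. A self-contained sketch on a hypothetical `TinyNet`, a stand-in for the pre-trained network so nothing is downloaded: freezing happens before the new `nn.Linear` is created, so only the new layer's weight and bias keep `requires_grad=True`. `count_params` is likewise a hypothetical helper.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pre-trained CNN: a conv backbone plus an `fc` head."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


def count_params(model):
    """Return (trainable, frozen) parameter counts."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    return trainable, frozen


model = TinyNet()
for param in model.parameters():
    param.requires_grad = False                # freeze everything, old head included
model.fc = nn.Linear(model.fc.in_features, 2)  # new head: requires_grad=True by default

print(count_params(model))  # → (18, 224)
```

The 18 trainable parameters are the new head's 2 x 8 weights plus 2 biases; the 224 frozen ones belong to the convolution layer.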
  2. Load part of a pre-trained model
from torchvision import models
# load resnet50; Net() is our own model class
resnet = models.resnet50(pretrained=True)
model = Net(...)

# load parameters
pretrained_dict = resnet.state_dict()
model_dict = model.state_dict()

# drop entries of pretrained_dict whose keys do not exist in model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# overwrite the matching entries of model_dict with the pre-trained values
model_dict.update(pretrained_dict)

# load the merged state_dict
model.load_state_dict(model_dict)
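Filtering by key name alone is not always enough: if a layer has the same name but a different shape (for example a classifier head with a different number of classes), `load_state_dict` raises a size-mismatch error, and `strict=False` does not suppress that error because it only tolerates missing or unexpected keys. A sketch that keeps only entries whose name and shape both match, using two small hypothetical `nn.Sequential` models in place of ResNet-50 and `Net`; `filter_matching` is a hypothetical helper.

```python
import torch
from torch import nn


def filter_matching(pretrained_dict, model_dict):
    """Keep pre-trained entries whose key exists in model_dict with the same shape."""
    return {
        k: v
        for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }


# hypothetical stand-ins: same first layer "0", different head sizes at "2"
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1000))
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

model_dict = model.state_dict()
matched = filter_matching(pretrained.state_dict(), model_dict)
model_dict.update(matched)
model.load_state_dict(model_dict)

print(sorted(matched))  # → ['0.bias', '0.weight']
```

The head (`2.weight`, `2.bias`) keeps its fresh initialization, while the shared first layer now holds the pre-trained values.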

3. Loss Function & Optimizer

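When finetuning, the list of parameters passed to the optimizer should be long and include all of the model's parameters. When doing feature extraction, however, it should be short and include only the weights and biases of the reshaped layers.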

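import torch
from torch import optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# send the model to the GPU (if one is available)
model_ft = model_ft.to(device)

# Gather the parameters to be optimized/updated in this run.
# If we are finetuning, we update all parameters.
# If we are doing feature extraction, we only update the parameters we just
# initialized, i.e. those with requires_grad == True.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t", name)
else:
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            print("\t", name)

# optimize only the parameters collected above
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)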

4. Details of Transfer Learning

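In general, transfer learning consists of the following steps:

  1. Initialize the pre-trained model
  2. Reshape the final layer so that it has as many outputs as the new dataset has classes
  3. Define which parameters the optimization algorithm should update during training
  4. Run the training step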
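The four steps can be sketched end to end on random data. This is a minimal illustration under stated assumptions: the "pre-trained" model is a hypothetical small `nn.Sequential` with a 10-class head (nothing is downloaded), the new task has 2 classes, and the loss is cross-entropy with SGD; the hyperparameters are illustrative only.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# 1. "Pre-trained" model: a hypothetical stand-in with a 10-class head
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 64), nn.ReLU(), nn.Linear(64, 10))
for p in model.parameters():
    p.requires_grad = False  # feature extraction: freeze everything

# 2. Reshape the last layer for the new dataset (2 classes)
model[3] = nn.Linear(model[3].in_features, 2)

# 3. Only parameters with requires_grad=True are handed to the optimizer
params_to_update = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params_to_update, lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 4. Training loop on random data (illustration only)
data = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 2, (32,)))
loader = DataLoader(data, batch_size=8, shuffle=True)
model.train()
for epoch in range(2):
    running_loss = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    print(f"epoch {epoch}: loss {running_loss / len(data):.4f}")
```

For finetuning instead of feature extraction, skip the freezing loop in step 1 so that step 3 collects every parameter, and typically lower the learning rate.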