上一篇我们介绍了迁移学习的核心思想和流程,我们介绍一个实例来加深理解。
传送门:迁移学习概述
获取预训练模型
pytorch和tensorflow都封装了很多预训练模型。
pytorch通过工具包torchvision.models模块获取,主要包括AlexNet、VGG系列、
ResNet系列、SqueezeNet和DenseNet等,通过设置参数pretrained=True即可获取。而Tensorflow内置在keras.application里面,当然,也可以通过TensorFlowHub网站自行下载。
from tensorflow.keras.applications import vgg16,resnet
from torchvision.models import AlexNet,VGG,ResNet
from torchvision.models import SqueezeNet,DenseNet
一个实例
下面通过一个例子对迁移学习有个感性的认识。预训练模型采用retnet18网络,一共分为八大步骤。
注:代码均来源于《深入浅出Embedding》第三章
1.导入模块
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torchvision.datasets import ImageFolder
from datetime import datetime
2.加载数据
加载相关数据集,首次下载需要将download设置为True,此外,还对数据做了一些预处理,标准化、图片裁剪等。