【动手学】37 微调_代码

%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
#@save               
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip',
                         'fba480ffa8aa7e0febbb511d181409f899b9baa5')
​
data_dir = d2l.download_extract('hotdog')
​
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))

图像的大小和纵横比各有不同

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs +not_hotdogs,2,8,scale=1.4)
输出:
array([<AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>,
       <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>, <AxesSubplot:>],
      dtype=object)

数据增广

normalize = torchvision.transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize
])
​
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize
])

 

定义和初始化模型

pretrained_net = torchvision.models.resnet18(pretrained=True)  #不仅将模型拿过来,而且将模型参数拿过来; 这是pretrain
pretrained_net.fc
输出:
Linear(in_features=512, out_features=1000, bias=True)
finetune_net = torchvision.models.resnet18(pretrained=True)        #下载预训练模型
finetune_net.fc = nn.Linear(finetune_net.fc.in_features,2)           #将最后全连接层,分类问题改成2分类问题
nn.init.xavier_normal_(finetune_net.fc.weight)                  #只对最后一层weight做随机初始化      
          Parameter containing:
tensor([[ 0.0026, -0.1623, -0.0596,  ..., -0.0453, -0.0063, -0.1450],
        [-0.0554,  0.0253, -0.0593,  ...,  0.0768,  0.1451,  0.0037]],
       requires_grad=True)
#如果param_graop =True ,输出层的模型参数将使用十倍的学习率
def train_fine_tuning(net,learning_rate,batch_size=128,num_epochs=5,param_group = True):
    train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
    os.path.join(data_dir,'train'),transform=train_augs),
    batch_size=batch_size,shuffle=True)
    
    test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(
    os.path.join(data_dir,'test'),transform=test_augs),
    batch_size=batch_size)
    #devices = d2l.try_all_gpus()
    devices = [0] 
    loss = nn.CrossEntropyLoss(reduction="none")
    
    if param_group:               #不是最后一层的拿出来,将最后一层设置为10倍学习率,因为他是随机初始的,希望训练快一点
        params_1x = [param for name,param in net.named_parameters()
                    if name not in ["fc.weight","fc.bias"]]
        trainer = torch.optim.SGD([{'params':params_1x},
                                  {'params':net.fc.parameters(),
                                   'lr':learning_rate* 10}],
                                 lr = learning_rate,weight_decay=0.001)
    else:
        trainer = torch.optim.SGD(net.parameters(),lr = learning_rate,
                                     weight_decay=0.001)
    d2l.train_ch13(net,train_iter,test_iter,loss,trainer,num_epochs,devices)
train_fine_tuning(finetune_net,5e-5)    #测试集精度高于训练集,overfiting不大
loss 0.346, train acc 0.884, test acc 0.934
726.8 examples/sec on [0]

对比实验,不使用与训练模型

scratch_net = torchvision.models.resnet18()
scratch_net.fc = nn.Linear(scratch_net.fc.in_features,2)
train_fine_tuning(scratch_net,5e-4,param_group=False)     #可以看到训练效果差了10个点
loss 0.344, train acc 0.851, test acc 0.859
725.2 examples/sec on [0]

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

瑾怀轩

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值