【Pytorch】迁移学习(Transfer Learning)

1.什么是迁移学习

迁移学习是一种机器学习技术,其中一个模型已在一个任务上训练好,并且该模型的经验可以用来更快地训练另一个相似的任务。这种技术的目的是为了减少在新任务上的训练时间,因为训练模型需要大量的数据和时间。

2.常见迁移学习方式

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全链接层,仅训练最后一个全链接层

3.例子

以kaggle中猫狗数据集为例,猫狗数据集
导包

import torch
import torchvision
from torch import nn, optim
from torchvision import transforms, datasets, models
from tqdm import tqdm
import sys

tqdm和sys,如果不需要进度条显示可不导入。
pytorch中迁移学习模型在torchvision.models

  1. vgg16
model = models.vgg16(pretrained=True)  # pretrained=True即为返回在 ImageNet (是数据集)上预训练的模型
for parameter in model.parameters():
    parameter.requires_grad = False	   # 冻结了所有层(参数不会更新)

此时模型已导入。可用model.buffer或直接print(model)查看模型。
模型结构
猫狗数据集为二分类,所以最后一层全连接层输出应为2,修改为:

model.classifier[6] = nn.Linear(in_features=4096, out_features=2, bias=True)

修改过后该全连接层不再被冻结,参数可被更新。
可用以下方式查看模型各层的冻结情况:

for m, n in model.named_parameters():
    print(m, n.requires_grad)

查看各层冻结情况
最后就是训练,训练最后一层全连接层的参数。

optimizer = optim.Adam(model.classifier.parameters(), lr=0.0001)

save_path = './save_path/vgg16_1.pth'

epochs = 15
train_steps = len(train_dataloader)
val_length = len(val_dataset)
best_acc = 0.0

for epoch in range(epochs):
    running_loss = 0.0
    model.train()
    train_bar = tqdm(train_dataloader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        output = model(images.to(device))
        loss = loss_function(output, labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss
        train_bar.desc = 'epoch:{}/{} loss:{:.3f}'.format(epoch+1, epochs, loss)
        
    model.eval()
    acc = 0.0
    with torch.no_grad():
        val_bar = tqdm(val_dataloader, file=sys.stdout)
        for data in val_bar:
            val_images, val_labels = data
            val_output = model(val_images.to(device))
            predict = torch.max(val_output, 1)[1]
            acc += torch.eq(predict, val_labels.to(device)).sum().item()
    
        val_accurate = acc/val_length
        print('epoch:{}/{} train_loss:{:.3f} val_accurate:{:.3f}'.format(epoch+1, epochs, running_loss/train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_path)

一共训练了15个epoch,一个epoch差不多43s,训练过程:
训练过程
可以看到验证集准确率差不多95%。

还可参考这篇文章

  • 2
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值