基于DANN的图像分类任务迁移学习

本文探讨了如何使用Domain-Adversarial Training of Neural Networks (DANN)进行迁移学习,特别是在图像分类任务中。通过数据预处理,使不同分布的图像变得相似,然后利用DANN算法训练模型,使得特征提取层能适应源数据和目标数据分布,从而提升无标签数据的分类准确率。实验结果显示,DANN有助于改善模型在目标分布数据上的表现。
摘要由CSDN通过智能技术生成

注:本博客的数据和任务来自NTU-ML2020作业,Kaggle网址为Kaggle.

数据预处理

我们要进行迁移学习的对象是10000张32x32x3的有标签正常照片,共有10类,和另外100000张人类画的手绘图,28x28x1黑白照片,类别也是10类但无标签。我们希望做到,让模型从有标签的原始分布数据中学到的知识能应用于无标签的,相似但与原始分布不相同的目标分布中,并提高黑白手绘图的正确率。
为此,训练前还要对数据做预处理。首先让原始分布的图像和目标分布的图像尽可能相似,我们要做有色图转灰度图,然后做边缘检测。为了模型的输入维度相同,要把28x28转为32x32.此外还可以增加一些平移旋转来让学习更鲁棒。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import cv2
import matplotlib.pyplot as plt

# 在transform中使用转灰度-canny边缘提取-水平移动-小幅度旋转-转张量操作

source_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Lambda(lambda x: cv2.Canny(np.array(x), 170, 300)),
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
])
target_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15, fill=(0,)),
    transforms.ToTensor(),
])

# 读取数据集,分为source和target两部分

source_dataset = ImageFolder('E:/real_or_drawing/train_data', transform=source_transform)
target_dataset = ImageFolder('E:/real_or_drawing/test_data', transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

DANN

Domain-Adversarial Training of NNs,值域对抗学习。这种算法是我们这里将要用的迁移学习方法,它被提出的起因是让CNN能够同时用于不同分布的数据,如果模型直接接收原值域的数据分布进行训练,即使原分布和目标分布有类似的地方,在接收目标值域的数据时,也会出现相当异常的特征提取和分类结果。我们可以理解为是模型在源数据分布上出现了过拟合(并不是对数据的过拟合),在接收一些没有见到过的数据时自然会表现不佳。
在这里插入图片描述
解决这个问题最好的办法就是让模型在训练时也接收目标数据分布的数据。但是目标数据分布是无标签的,我们要用什么标准来训练模型呢?回忆CNN的架构,CNN使用卷积-池化的特征提取层来提取图片特征,后接全连接层进行预测。我们只需要让特征提取层既能提取原数据分布的特征,又能提取目标数据分布的特征,这样全连接层就能对两种值域但具有相同特征的数据进行同样的分类,从而目标数据分布的输入也很有可能被正确分类。
在这里插入图片描述
那么问题就变成了如何训练输入两个不同分布的数据,输出却是同种分布的特征提取层。回忆GAN的架构,我们让分布朝着源数据分布发展的方法是建立判别器,让判别器能分辨两种数据,而让生成器改变参数骗过判别器。这里也可以用同样的思想,我们建立能分辨原始分布和目标分布的二分类判别器,把特征提取层和二分类判别层接在一起。首先训练判别器,让判别器能分辨两类数据分布。然后训练特征提取层,逆梯度更新让特征提取层生成能骗过判别器的数据(目标输出0.5).如此训练多次直到特征提取层能把两种值域的输入变成同种分布的输出。
在这里插入图片描述
但是只是用GAN方法train特征提取层并不明智,因为我们的目标输出只有0-1的二分类,训练很有可能只是让特征提取层提取到一些没有用的特征。因此我们要一边训练正常的标签预测任务,一边训练判别器的判别任务和混淆两类输入的任务。这可能需要自己定义特殊的loss function

最后,我们就获得了能同时提取两个值域的特征的特征提取层,它后面的多分类层就可以对目标分布的数据做出还算称心如意的预测。

模型、训练、测试代码

这里使用类VGG(用多个3x3的卷积核代替大型卷积核以节约参数)的搭建方式,写一个高度卷积的特征提取层

class 
DANN (Domain Adversarial Neural Network) 是一种用于迁移学习的方法,它可以在源领域和目标领域之间进行知识迁移。DANN的目标是通过最小化源领域和目标领域之间的领域差异,来实现在目标领域上的良好泛化性能。 DANN的代码实现可以使用深度学习框架如PyTorch或TensorFlow来完成。以下是一个简单的DANN代码示例: ```python import torch import torch.nn as nn import torch.optim as optim # 定义源领域和目标领域的数据加载器 source_loader = ... target_loader = ... # 定义DANN模型 class DANNModel(nn.Module): def __init__(self): super(DANNModel, self).__init__() # 定义共享特征提取器 self.feature_extractor = nn.Sequential( nn.Conv2d(3, 64, kernel_size=5), nn.ReLU(), ... ) # 定义分类器 self.classifier = nn.Sequential( nn.Linear(64 * 5 * 5, 100), nn.ReLU(), nn.Linear(100, 10), nn.LogSoftmax(dim=1) ) # 定义领域分类器 self.domain_classifier = nn.Sequential( nn.Linear(64 * 5 * 5, 100), nn.ReLU(), nn.Linear(100, 2), nn.LogSoftmax(dim=1) ) def forward(self, x): features = self.feature_extractor(x) features = features.view(features.size(0), -1) class_output = self.classifier(features) domain_output = self.domain_classifier(features) return class_output, domain_output # 初始化模型和优化器 model = DANNModel() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 定义损失函数 classification_loss = nn.NLLLoss() domain_loss = nn.NLLLoss() # 训练DANN模型 for epoch in range(num_epochs): for source_data, target_data in zip(source_loader, target_loader): # 将源数据和目标数据输入模型 source_inputs, source_labels = source_data target_inputs, _ = target_data source_class_output, source_domain_output = model(source_inputs) target_class_output, target_domain_output = model(target_inputs) # 计算分类损失和领域损失 class_loss = classification_loss(source_class_output, source_labels) domain_loss = domain_loss(source_domain_output, torch.zeros(source_domain_output.size(0))) domain_loss += domain_loss(target_domain_output, torch.ones(target_domain_output.size(0))) # 总损失为分类损失加上领域损失 total_loss = class_loss + domain_loss # 反向传播和优化 optimizer.zero_grad() total_loss.backward() optimizer.step() # 进行预测 target_inputs, _ = target_loader target_class_output, _ = model(target_inputs) predictions = torch.argmax(target_class_output, dim=1) ``` 这只是一个简单的DANN代码示例,实际应用中可能需要根据具体任务和数据进行适当的修改和调整。希望对你有所帮助!
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值