【迁移学习】深度域自适应网络DANN模型

本文介绍了深度域自适应网络DANN,用于解决源域和目标域数据分布不同但任务相同的迁移学习问题。DANN模型包含特征提取、图像分类和域分类三个部分,并利用梯度反转层实现域分类损失和分类预测损失的平衡。通过在MNIST和MNIST-M之间的迁移学习实验,展示了DANN的有效性。提供了相关实验的GitHub链接和所需运行环境及库。
部署运行你感兴趣的模型镜像

DANN

  • Domain-Adversarial Training of Neural Networks in Tensorflow
  • 域适配:目标域与源域的数据分布不同但任务相同下的迁移学习。

模型建立

  • DANN假设有两种数据分布:源域数据分布 S ( x , y ) \mathcal{S}(x,y) S(x,y)和目标域数据分布 T ( x , y ) \mathcal{T}(x,y) T(x,y);定义 d i d_i di为第 i i i个训练样本的域标签, d i ∈ { 0 , 1 } d_i\in\{0,1\} di{0,1},若 d i = 0 d_i=0 di=0,则 x i ∼ S ( x ) x_i\sim\mathcal{S}(x) xiS(x),反之 d i = 1 d_i=1 di=1 x i ∼ T ( x ) x_i\sim\mathcal{T}(x) xiT(x)

  • DANN网络的输入 x ∈ X x\in X xX X X X表示图像输入空间;图像分类标签 y ∈ Y y\in Y yY,其中 Y ( Y = { 1 , 2 , 3 , . . . , k } ) Y(Y=\{1,2,3,...,k\}) Y(Y={1,2,3,...,k})表示图像分类标签空间。在输入中,既有带标签的源域数据集也有不带标签的目标域数据集。

  • DANN的目标:准确预测目标域输入图像的分类标签

  • DANN网络构架:
    在这里插入图片描述
    绿色部分:特征提取网络;蓝色部分:图像分类网络;(绿蓝一起构成标准前馈结构)红色部分:域分类网络

    • 输入 x x x在训练阶段首先经过特征提取网络 f = G f ( x ; θ y ) f=G_f(x;\theta_y) f=Gf(x;θy)的映射转换为一个 D D D维特征向量。

      • 分支1:图像分类预测网络 G y ( x ; θ y ) G_y(x;\theta_y) Gy(x;θy)

        源域数据对应的特征向量通过预测网络获得对应分类标签预测结果。

      • 分支2:域分类网络 G d ( x ; θ d ) G_d(x;\theta_d) Gd(x;θd)

        源域和目标域输入对应的特征向量经过域分类网络得到每个输入的域分类结果。

  • 为了实现让DANN将目标域数据看成源域数据,在训练阶段需实现两个任务:

  1. 源域数据及准确分类,即分类误差最小化
  2. 混淆源域数据集和目标数据集,实现域分类误差最大化。

由于域分类器与图像分类器的输入都来自于特征提取网络,但分别要求分类损失最小化和最大化。

如果使用随机梯度下降法(SGD)来优化模型损失函数的,会导致特征提取器在进行参数更新时,域分类损失和分类预测损失梯度相反

  • DANN提出了新的梯度反转层(Gradient Reversal Layer,GRL),使得在反向传播的过程中梯度方向自动取反,在前向传播过程中只做恒等变换:

R λ ( x ) = x d R λ d x = − λ I \begin{gathered}R_\lambda(x)=x\\\frac{dR_{\lambda}}{dx}=-\lambda I\end{gathered} Rλ(x)=xdxdRλ=λI

  • 梯度反转层被插在特征提取层和与分类器之间,则在反向传播的过程中,域分类器的域分类损失的梯度反向传播到特征提取器之前会自动取反( λ \lambda λ随迭代次数动态变化)

论文:Unsupervised Domain Adaptation by Backpropagation

参考博客:【深度域适配】一、DANN与梯度反转层(GRL)详解

实验:MNIST域与MNIST-M域间的迁移学习

github传送门:https://github.com/pumpikano/tf-dann

运行环境

  • tensorflow >=1.0
  • Python 3.4

  • numpy
  • scikit-image
  • matplotlib
  • scikit-learn
  • jupyter
  • scipy

实验数据

  • MNIST作为源域,MNIST-M作为目标域。

    MNIST-M为将从BSDS的彩色照片中随机抽取的原始集合上的数字随机混合到MNIST的patch上——与原始数据及相比,背景和笔画不再恒定,域相当不同,但输出图片对人类来说仍看可区分出数字。

  • 样本均为28*28*3的图片,3代表RGB值

源域数据
源域数据与目标域数据
输入x:batch*28*28*3
两层CNN特征提取器
输出:batch*7*7*48
预测分类器
梯度反转层
域分类器
输出预测y
计算损失函数梯度
输出域分类结果d
计算损失函数梯度

您可能感兴趣的与本文相关的镜像

Stable-Diffusion-3.5

Stable-Diffusion-3.5

图片生成
Stable-Diffusion

Stable Diffusion 3.5 (SD 3.5) 是由 Stability AI 推出的新一代文本到图像生成模型,相比 3.0 版本,它提升了图像质量、运行速度和硬件效率

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值