[NIPS 2018] Generalized Zero-Shot Learning with Deep Calibration Network

基于深度校准网络的广义零样本学习
Generalized Zero-Shot Learning with Deep Calibration Network

本文亮点:在训练时使用目标域的标签

文章目录


1 Introduction 引言

在这里插入图片描述
动机/问题:广义零样本学习的技术难点。对已见类数据的过拟合导致对目标类别(已见类和未见类)的不确定预测,是GZSL性能低的原因。

如图,经过校正的网络预测更加准确。
问题:预测更加准确,是否能够提升分类精度?能够在实验中给出前后对比吗?


3 广义零样本学习

符号:
已见数据 D = { ( x n , y ) n ) } n = 1 N \mathcal{D} = \{ (x_n, y)n) \}_{n=1}^N D={(xn,y)n)}n=1N
源类别 S = { 1 , ⋯   , S } \mathcal{S}=\{ 1, \cdots, S \} S={1,,S}
目标类 T = { S + 1 , ⋯   , S + T } \mathcal{T}=\{ S+1, \cdots, S+T \} T={S+1,,S+T}, 训练时样本不可见
一个类别 c ∈ { S ∪ T } c \in \{\mathcal{S \cup T}\} c{ST}的语义表示为 a c ∈ R Q a_c \in \mathbb R^Q acRQ
所有类别的语义表示 A = { a c } c = 1 S + T \mathcal{A}=\{a_c\}_{c=1}^{S+T} A={ac}c=1S+T
未见类数据 D ′ = { x m } m = N + 1 N + M \mathcal{D'} = \{ x_m \}_{m=N+1}^{N+M} D={xm}m=N+1N+M, 源类或者目标类数据

定义1:零样本,ZSL Given D \mathcal{D} D and { a c } c = 1 S \{a_c\}_{c=1}^{S} {ac}c=1S, classify D \mathcal{D} D over target classes T \mathcal{T} T.
定义2:广义零样本,GZSL Given D \mathcal{D} D and { a c } c = 1 S + T \{a_c\}_{c=1}^{S+T} {ac}c=1S+T of both source and target classes, learn a model f : x ↦ y f: x \mapsto y f:xy to classify D ′ \mathcal{D'} D over both source and target classes S ∪ T \mathcal{S \cup T} ST.

在这个定义里,ZSL没有利用目标域的标签。

3.1 预测函数

图像 x ∈ D x \in \mathcal{D} xD
特征嵌入 ϕ ( x ) ∈ R K \phi(x) \in \mathbb R^K ϕ(x)RK
类别语义 a ∈ A a \in \mathcal{A} aA,属性或者词向量
语义嵌入 ψ ( a ) ∈ R K \psi(a) \in \mathbb R^K ψ(a)RK

这里的嵌入空间就是特征空间,论文给出的是2048维的ResNet特征或者1024维的GoogleNet特征

图像的视觉嵌入 z n = ϕ ( x n ) z_n = \phi(x_n) zn=ϕ(xn)
类别的语义嵌入 v c = ψ ( a c ) v_c = \psi(a_c) vc=ψ(ac)

预测函数
f c ( x n ) = s i m ( ϕ ( x n ) , ψ ( a c ) ) f_c(x_n) = \rm sim(\phi(x_n), \psi(a_c)) fc(xn)=sim(ϕ(xn),ψ(ac))
s i m ( . , . ) \rm sim(., .) sim(.,.)是相似度函数,比如內积和余弦相似度; f c ( x n ) f_c(x_n) fc(xn)是(nearest prototype classifier) NPC分类器分配给图像 x n x_n xn类别 c c c的强度。

图像 x n x_n xn的预测类别 y ( x n ) y(x_n) y(xn)
y ( x n ) = arg ⁡ max ⁡ c f c ( x n ) y(x_n)=\arg \max_c f_c(x_n) y(xn)=argcmaxfc(xn)

论文提到,预测源类和目标类的导致的技术难度是不一样的。

3.2 风险最小化

multi-class Hinge loss
∑ n = 1 N ∑ c = 1 S = max ⁡ ( 0 , Δ ( y n , c ) + f c ( x n ) − f y n ( x n ) ) \sum_{n=1}^{N}\sum_{c=1}^{S}=\max (0, \Delta(y_n, c) + f_c(x_n)-f_{y_n}(x_n) ) n=1Nc=1S=max(0,Δ(yn,c)+fc(xn)fyn(xn))
其中,间隔定义为
Δ ( y n , c ) = { 0 y n = c 1 y n ! = c \Delta(y_n, c) = \begin{cases} 0& {y_n = c}\\ 1& {y_n != c} \end{cases} Δ(yn,c)={01yn=cyn!=c
文中提到大部分零样本学习方法使用多分类Hinge损失来学习视觉语义映射。

作者应用温度校正来缓解由于在已见数据上的过拟合导致的对源域类别的过分相信。温度校正是Hinton老爷子提出来从深度网络蒸馏知识的。作者应用温度校正来将预测 f f f转换到源于类别上的概率分布

p c ( x n ) = exp ⁡ ( f c ( x n ) / τ ) ∑ c ′ = 1 S exp ⁡ ( f c ′ ( x n ) / τ ) p_c(x_n) = \frac {\exp(f_c(x_n)/\tau)} {\sum_{c'=1}^{S} \exp(f_{c'}(x_n)/\tau)} pc(xn)=c=1Sexp(fc(xn)/τ)exp(fc(xn)/τ)

其中, τ \tau τ就是温度,当 τ = 1 \tau=1 τ=1是深度网络里最常见的选项。温度 τ \tau τ τ > 1 \tau>1 τ>1“软化”了softmax。当 τ → ∞ \tau \to \infty τ时,概率 p c → 1 / S p_c \to 1/S pc1/S,这将导致最大的不确定性。当 τ → 0 \tau \to 0 τ0时,概率坍缩到一点(即 p c = 1 p_c = 1 pc=1)。因为 τ \tau τ不改变softmax函数的最大值,收敛后如果应用温度校正 τ ≠ 1 \tau \neq 1 τ=1

将概率 p c p_c pc插入到源域类别 S S S的可见数据 D \mathcal D D上的交叉熵损失得到

L = − ∑ n = 1 N ∑ c = 1 S y n , c log ⁡ p c ( x n ) . (6) L = -\sum_{n=1}^{N} \sum_{c=1}^{S} y_{n, c} \log{p_c(x_n)}. \tag{6} L=n=1Nc=1Syn,clogpc(xn).(6)

关于这个loss,作者认为,相比于multi-class Hinge loss,虽然交叉熵是一个很简单的处理多分类的方案,但能够利用温度校正来缓解过拟合。

3.3 不确定性校准

不管是ZSL还是GZSL,都强调了模型训练不能使用目标域训练数据。但是,要用模型识别目标域的数据,必须让模型学习目标域的知识。所以就只能用到目标域的语义信息。

作者提出,将模型的预测 f c f_c fc转换成目标域上的概率(带有温度校正)。

q c ( x n ) = exp ⁡ ( f c ( x n ) / τ ) ∑ c ′ = S + 1 S + T exp ⁡ ( f c ′ ( x n ) / τ ) (7) q_c(x_n) = \frac {\exp (f_c(x_n)/\tau)} {\sum_{c'=S+1}^{S+T} \exp(f_{c'}(x_n)/\tau)} \tag{7} qc(xn)=c=S+1S+Texp(fc(xn)/τ)exp(fc(xn)/τ)(7)

温度校正 τ ≠ 1 \tau \neq 1 τ=1在公式(6)和(7)的端到端的训练中都会用到。

解释
直观上讲,目标域 c c c和源域图片 x n x_n xn对应的源域越相似,概率 q c ( x n ) q_c(x_n) qc(xn)的值越大。这样就避免了训练时源域图像对目标域图像的不确定性一致。在信息论中,熵 h ( q ) = − q log ⁡ q h(q)=-q\log{q} h(q)=qlogq是对分布 q q q的不确定性的度量。值越低,不确定性越小。在本文中,作者提出了基于熵准则的不确定性校正的目标函数:

H = − ∑ n = 1 N ∑ c = S + 1 S + T q c ( x n ) log ⁡ q c ( x n ) . (8) H = -\sum_{n=1}^{N} \sum_{c=S+1}^{S+T} q_c(x_n) \log{q_c(x_n)}. \tag{8} H=n=1Nc=S+1S+Tqc(xn)logqc(xn).(8)

需要实验去看看,这个效果怎么样

3.4 深度校准网络

优化目标如下:
min ⁡ L + λ H + γ Ω ( ϕ , ψ ) , (9) \min { L + \lambda{H} + \gamma{\Omega (\phi, \psi)} }, \tag{9} minL+λH+γΩ(ϕ,ψ),(9)

Ω ( ϕ , ψ ) \Omega (\phi, \psi) Ω(ϕ,ψ)是模型复杂度惩罚项。在深度学习中,可以用权值衰减来替代它。


4 实验

4.2 Standard ZSL

在这里插入图片描述

在这里插入图片描述
τ \tau τ越小,不确定性越大。

4.3 GZSL结果

在这里插入图片描述
最后三行结果表明,不确定性校正的高效。

GZSL的精度比ZSL低很多,为什么?

  1. 源域的精度低,是为什么?
  2. 目标域精度低,是为什么?模型对源域过拟合。
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值