Fisher矩阵相关理论研究

本文目的是针对《Overcoming catastrophic forgetting in neural networks》文中的EWC方法提到的Fisher矩阵进行相关知识调研和记录。

1. Fisher矩阵与自然梯度

首先针对防止灾难性遗忘的方法EWC中提到的Fisher矩阵,其引用的自然梯度下降方法如下:
参考论文《Revisiting natural gradient for deep networks》
考虑一个密度函数族 F \mathcal{F} F将参数 θ ∈ R P \theta \in \mathbb{R}^{P} θRP映射到概率密度函数 p ( z ) p(\bold{z}) p(z) p : R N → [ 0 , ∞ ) p:\mathbb{R}^N \rightarrow [0,\infin) p:RN[0,),其中 z ∈ R N z \in \mathbb{R}^N zRN θ ∈ R P \theta \in \mathbb{R}^{P} θRP的任何选择定义了一个特定的密度函数 p θ ( z ) = F ( θ ) ( z ) p_{\theta}(\bold{z})=\mathcal{F}(\theta)(\bold{z}) pθ(z)=F(θ)(z),并通过考虑所有可能的 θ \theta θ值,探索了集合函数流形 F \mathcal{F} F

在其无穷小的形式中,KL 散度的行为类似于距离度量,因此我们可以定义附近密度函数之间的相似性度量。因此 F \bold{F} F是一个黎曼流形,其度量由以下等式中定义的 Fisher 信息矩阵 F θ \bold{F}_{\theta} Fθ给出:
F θ = E z [ ( ∇ log ⁡ p θ ( z ) ) T ( ∇ log ⁡ p θ ( z ) ) ] (1) \bold{F}_{\theta}=\mathbb{E}_z[(\nabla \log p_{\theta}(\bold{z}))^T(\nabla \log p_{\theta}(\bold{z}))] \tag{1} Fθ=Ez[(logpθ(z))T(logpθ(z))](1)
也就是说,在某个点 θ \theta θ周围局部,该度量定义了向量 u u u v v v之间的内积: < u , v > θ = u F θ v (2) <u,v>_{\theta}=u\bold{F}_{\theta}v \tag{2} <u,v>θ=uFθv(2)
因此,它提供了距离的局部度量。假设该矩阵对 θ \theta θ的隐式依赖性,接下来为 Fisher 信息矩阵编写 F \bold{F} F
给定由 θ \theta θ参数化的损失函数 L \mathcal{L} L,自然梯度下降试图通过根据KL发散面的局部曲率校正L的梯度来沿着流形移动,即在方向上移动给定的距离 ∇ N L ( θ ) \nabla_N \mathcal{L}(\theta) NL(θ):
∇ N L ( θ ) = d e f ∇ L ( θ ) E z [ ( ∇ log ⁡ p θ ( z ) ) T ( ∇ log ⁡ p θ ( z ) ) ] − 1 = d e f ∇ L ( θ ) F − 1 \begin{align} \nabla_N \mathcal{L}(\theta) &\overset{def}{=} \nabla \mathcal{L}(\theta) \mathbb{E}_z[(\nabla \log p_{\theta}(\bold{z}))^T(\nabla \log p_{\theta}(\bold{z}))]^{-1} \tag{3}\\ &\overset{def}{=} \nabla \mathcal{L}(\theta) \bold{F}^{-1} \tag{4} \end{align} NL(θ)=defL(θ)Ez[(logpθ(z))T(logpθ(z))]1=defL(θ)F1(3)(4)
对自然梯度使用 ∇ N \nabla_N N,对梯度使用 ∇ \nabla F \bold{F} F是Fisher信息矩阵给出的度量矩阵。在这项工作中,偏导数通常表示为行向量。我们可以通过将自然梯度下降定义为算法来导出这一结果,该算法在每一步都试图选择下降方向,使我们的模型中引起的变化量(在KL意义上)为某个给定值。特别地,当 p θ p_θ pθ p θ + ∆ θ p_{θ+∆θ} pθ+θ之间的KL散度的二阶泰勒级数必须是常数时,我们寻找一个小的 ∆ θ ∆θ θ,它最小化 L \mathcal{L} L的一阶泰勒展开:
arg min ⁡ ∆ θ L ( θ + ∆ θ ) s . t . K L ( p θ ∣ ∣ p θ + ∆ θ ) = c o n s t . \begin{align} & \argmin_{∆θ} \mathcal{L}(θ+∆θ) \notag\\ &\mathrm{ s.t. } KL(p_{\theta} || p_{θ+∆θ}) =const. \tag{5} \end{align} θargminL(θ+θ)s.t.KL(pθ∣∣pθ+θ)=const.(5)
使用此约束,确保以恒定速度沿函数流形移动,而不会因其曲率而减慢速度。这也使得学习对模型的重新参数化具有局部鲁棒性,因为 p p p的函数行为不取决于它是如何参数化的。

假设 ∆ θ → 0 ∆θ→ 0 θ0,我们可以通过其二阶泰勒级数来近似KL散度:
K L ( p θ ∣ ∣ p θ + ∆ θ ) ≈ ( E z [ log ⁡ p θ ] − E z [ log ⁡ p θ ] ) − E z [ ∇ log ⁡ p θ ( z ) ] ∆ θ − 1 2 ∆ θ T E z [ ∇ 2 log ⁡ p θ ] ∆ θ = 1 2 ∆ θ T E z [ − ∇ 2 log ⁡ p θ ] ∆ θ = 1 2 ∆ θ T F ∆ θ \begin{align} KL(p_{\theta} || p_{θ+∆θ}) &\approx (\mathbb{E}_z[\log p_{\theta}]- \mathbb{E}_z[\log p_{\theta}]) \notag \\ &-\mathbb{E}_z[\nabla \log p_{\theta}(\bold{z})]∆θ-\frac{1}{2}∆θ^T\mathbb{E}_z[\nabla^2 \log p_{\theta}] ∆θ \notag\\ &= \frac{1}{2} ∆θ^T \mathbb{E}_z[-\nabla^2 \log p_{\theta}]∆θ \notag \\ &=\frac{1}{2}∆θ^T \bold{F} ∆θ \tag{6} \end{align} KL(pθ∣∣pθ+θ)(Ez[logpθ]Ez[logpθ])Ez[logpθ(z)]θ21θTEz[2logpθ]θ=21θTEz[2logpθ]θ=21θTFθ(6)
第一项抵消,由于 E z [ ∇ log ⁡ p θ ( z ) ] = ∑ z ( p θ ( z ) 1 p θ ( z ) ∂ p θ ( z ) ∂ θ ) = ∂ ∂ θ ( ∑ θ p θ ( z ) ) = ∂ 1 ∂ θ = 0 \mathbb{E}_z[\nabla \log p_{\theta}(\bold{z})]= \sum_{\bold{z}}(p_{\theta}(\bold{z}){\frac{1}{p_\theta(\bold{z})} \frac{\partial p_\theta(\bold{z})}{\partial \theta}})=\frac{\partial}{\partial \theta}(\sum_{\theta} p_{\theta}(\bold{z}))=\frac{\partial 1}{\partial \theta}=0 Ez[logpθ(z)]=z(pθ(z)pθ(z)1θpθ(z))=θ(θpθ(z))=θ1=0,只保留最后一项。另一方面Fisher信息矩阵形式可以通过代数运算从Hessian的期望值获得。

现在将等式 (5) 表示为拉格朗日函数,其中 KL 散度由 (6) 和 L ( θ + ∆ θ ) \mathcal{L}(θ+∆θ) L(θ+θ)通过其一阶泰勒级数 L ( θ ) + ∇ L ( θ ) ∆ θ \mathcal{L}(θ)+\nabla \mathcal{L}(θ)∆θ L(θ)+L(θ)θ近似:
L ( θ ) + ∇ L ( θ ) ∆ θ + 1 2 λ ∆ θ T F ∆ θ (7) \mathcal{L}(θ)+\nabla \mathcal{L}(θ)∆θ+\frac{1}{2}\lambda ∆θ^T \bold{F} ∆θ \tag{7} L(θ)+L(θ)θ+21λθTFθ(7)
求解 ∆ θ ∆θ θ的方程(7),得到自然梯度下降公式(4)。得到自然梯度的 2 1 λ 2\frac{1}{\lambda} 2λ1倍的标量因子。我们将此标量折叠到学习率中,现在还控制我们在保持 p θ p_θ pθ p θ + Δ θ p_{θ+Δθ} pθ+Δθ 之间的 KL 距离的权重。我们使用的近似值仅在 θ θ θ左右有意义:在 Schul(2012) 研究中,表明采取大步骤可能会损害收敛。通过使用阻尼(即设置 θ θ θ周围的信任区域)和正确选择学习率来处理此类问题。

2. Fisher矩阵分析

提示:这里的分析结合了“通义千问”的回答
2.1 自然梯度的几何解释:
在优化过程中,我们希望找到使目标函数 L ( θ ) \mathcal{L}(θ) L(θ) 最小化的参数 θ θ θ。标准梯度下降是在欧几里得空间中工作,沿梯度 ∇ θ L ( θ ) \nabla_\theta \mathcal{L}(θ) θL(θ)的方向是最陡峭下降的方向。然而,在处理概率分布参数时,参数空间往往具有非欧几里得几何结构。Fisher矩阵可以用来定义参数空间上的一个Riemannian度量,即参数间的“自然”距离。在这种几何框架下,自然梯度 G n a t ( θ ) G_{nat}(θ) Gnat(θ) 是在这个参数空间中沿着目标函数最陡峭下降方向的向量。

2.2 Fisher矩阵的物理意义:
Fisher矩阵 I ( θ ) I(θ) I(θ) 可以看作是参数 θ θ θ在给定概率模型 p ( x ∣ θ ) p(x∣θ) p(xθ)下的信息含量。它量化了当观测数据 x x x 固定时,参数 θ θ θ的微小变动对对数似然函数 log ⁡ p ( x ∣ θ ) \log p(x∣θ) logp(xθ)的影响。Fisher矩阵的每个元素 I i j ( θ ) I_{ij}(θ) Iij(θ)表示参数 θ i θ_i θi θ j θ_j θj之间的二阶偏导数期望,即参数间变化的局部相关性。Fisher矩阵是对称的,且非负定,其逆矩阵 I ( θ ) − 1 I(θ)^{-1} I(θ)1描述了参数估计的协方差矩阵的逆,反映了参数估计的不确定性。

2.3 Cramér-Rao Bound的关联:
Fisher信息矩阵与参数估计的精度有着直接联系。Cramér-Rao Bound (CRB) 是一个统计学的基本定理,它指出在无偏估计的情况下,任何估计量的协方差矩阵的逆(即精度矩阵)至少要等于Fisher信息矩阵。换句话说,Fisher信息矩阵的逆给出了参数估计误差协方差的下界。因此,参数的Fisher信息量越大,其估计的不确定性理论上就越小,即该参数越重要,我们能更精确地估计它。

2.4 Fisher矩阵与KL散度的关系:
自然梯度与Fisher信息矩阵之间的紧密联系源于它们与Kullback-Leibler (KL) 散度的关系。KL散度是衡量两个概率分布 p p p q q q之间差异的一个常用指标。在自然梯度的背景下,我们关心的是参数 θ θ θ的微小变化如何影响模型分布 p ( x ∣ θ ) p(x∣θ) p(xθ)。当 θ θ θ改变为 θ + Δ θ θ+Δθ θ+Δθ时,KL散度 D K L ( p ( x ∣ θ ) ∣ ∣ p ( x ∣ θ + Δ θ ) ) D_{KL}(p(x∣θ)∣∣p(x∣θ+Δθ)) DKL(p(xθ)∣∣p(xθ+Δθ))可以展开为泰勒级数,其中二阶项与Fisher矩阵相关:
D K L ( p ( x ∣ θ ) ∣ ∣ p ( x ∣ θ + Δ θ ) ) ≈ 1 2 ∆ θ T F ∆ θ D_{KL}(p(x∣θ)∣∣p(x∣θ+Δθ)) \approx \frac{1}{2}∆θ^T \bold{F} ∆θ DKL(p(xθ)∣∣p(xθ+Δθ))21θTFθ
这意味着Fisher矩阵描述了参数变化 Δ θ Δθ Δθ对KL散度的二次贡献。在优化过程中,我们希望最小化KL散度(或者等价地,最大化对数似然),自然梯度的方向正是使得KL散度相对于参数变化 Δ θ Δθ Δθ最陡峭下降的方向。

3. EWC中的Fisher矩阵

EWC防止灾难性遗忘方法的训练目标如下:
L ( θ ) = L B ( θ ) + ∑ i λ 2 F i ( θ i − θ A , i ∗ ) 2 \mathcal{L}(\theta)=\mathcal{L}_B(\theta)+\sum_i \frac{\lambda}{2}F_i(\theta_i-\theta^*_{A,i})^2 L(θ)=LB(θ)+i2λFi(θiθA,i)2
但是面临以下几个问题:
3.1 为什么用fisher矩阵,不用一阶导:

一阶导数仅提供局部梯度信息: Score函数给出的是似然函数在参数空间中的梯度,它指示了似然函数增大的方向,用于找到最大似然估计。然而,梯度仅描述了参数空间中某一点的线性趋势,无法捕捉非线性效应和参数间的相互依赖关系,这些非线性效应和依赖关系对模型的整体行为至关重要。
二阶矩反映曲率和局部稳定性: Fisher信息矩阵包含了对数似然函数的二阶偏导数,它刻画了似然函数在参数空间中的曲率。曲率反映了模型对参数微小变动的响应程度,即参数敏感性。高的曲率意味着参数稍有变化就会导致似然函数显著变化,这对应于参数对模型输出具有高度影响力。此外,曲率还与参数估计的局部稳定性相关,即在估计过程中参数是否容易受到噪声或数据扰动的影响。
统计推断的必要性: 在统计推断中,我们通常关心的是参数的置信区间和估计的精度,这些都与参数估计的方差(或协方差)紧密相关。一阶导数无法直接提供这些信息,而Fisher信息矩阵的逆恰好给出了参数估计误差协方差的下界(Cramér-Rao Bound),这与实际的推断需求直接对应。

3.2 为什么EWC是乘上参数之间的差:
EWC正则化项中的差值 ( θ i − θ i ∗ ) (\theta_i -\theta_i^*) (θiθi)表示当前参数与旧任务最优参数之间的差异,平方后乘以Fisher信息量 F i F_i Fi,用来衡量这种差异对于旧任务的影响。大的Fisher信息量意味着参数对于旧任务非常关键,因此当该参数偏离旧任务最优值时,正则化项会施加较大的惩罚,限制其在学习新任务时的变动。

3.3 Fisher矩阵的代码如何实现(只需要计算对角线即可):

def getFisherDiagonal(self, train_loader):
        fisher = {
            n: torch.zeros(p.shape).to(self._device)
            for n, p in self._network.named_parameters()
            if p.requires_grad
        }
        self._network.train()
        optimizer = optim.SGD(self._network.parameters(), lr=lrate)
        for i, (_, inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(self._device), targets.to(self._device)
            logits = self._network(inputs)["logits"]
            loss = torch.nn.functional.cross_entropy(logits, targets)
            optimizer.zero_grad()
            loss.backward()
            for n, p in self._network.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.pow(2).clone()
        for n, p in fisher.items():
            fisher[n] = p / len(train_loader)
            fisher[n] = torch.min(fisher[n], torch.tensor(fishermax))
        return fisher

小结

本部分内容主要总结了Fisher矩阵的理论知识,为下一步的相关工作进行知识铺垫。

  • 20
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值