[持续学习] Fisher信息矩阵与EWC

一、前置知识

1. 得分函数 score / informant

score / informant 定义为对数似然函数关于参数的梯度:

s ( θ ) ≡ ∂ log ⁡ L ( θ ) ∂ θ s(\theta) \equiv \frac{\partial\log{\mathcal{L}(\theta)}}{\partial\theta} s(θ)θlogL(θ)

其中 L ( θ ) \mathcal{L}(\theta) L(θ)即为似然函数,可扩写为 L ( θ ∣ x ) \mathcal{L}(\theta |x) L(θx),其中 x x x为观测到的数据, x x x从采样域 X \mathcal{X} X中产生

在某一特定点的 s s s 函数指明了该点处对数似然函数的陡峭程度(steepness),或者是函数值对参数发生无穷小量变化的敏感性。

如果对数似然函数定义在连续实数值的参数空间中,那么它的函数值将在(局部)极大值与极小值点消失。这一性质通常用于极大似然估计中(maximum likelihood estimation, MLE),来寻找使得似然函数值极大的参数值。


注意 L ( θ ∣ x ) \mathcal{L}(\theta |x) L(θx)中竖线前后的字母 θ ∣ x \theta|x θx x x x为随机变量,在这里则是一个定值,意为采样后的观测值,而 θ \theta θ则为自变量,意为参数模型中的参数

当(假设 θ \theta θ位于正确值时,我们可以通过 θ \theta θ推导 x x x,也就是 f ( x ∣ θ ) f(x|\theta) f(xθ) ,为一概率密度函数,意为当模型参数为 θ \theta θ时,采样到 x x x的概率

从两个角度得到了对同一事实的论证,因此可写作 f ( x ∣ θ ) = L ( θ ∣ x ) f(x|\theta) = \mathcal{L}(\theta | x) f(xθ)=L(θx)


首先,来分析 s s s的数学期望,这里讨论的问题是:当参数取值为 θ \theta θ时, s ∣ θ s|\theta sθ的数学期望

从直观上分析,当参数位于真实最佳)参数点时,似然函数有其极大值(考虑极大似然估计的定义),因此为一极值点,所以该点梯度为 0 0 0,即 E [ s ∣ θ ] = 0 \mathbb{E}[s|\theta]= 0 E[sθ]=0

下面进行公式分析:

首先要明确,该期望是 s s s函数关于什么随机变量的期望。从上面的讨论中可以得到,该问题中唯一的随机变量是采样观测值 x x x,它的采样概率是 f ( x ∣ θ ) f(x|\theta) f(xθ)

注意:

f ( x ) ∂ log ⁡ f ( x ) ∂ x = f ( x ) 1 f ( x ) ∂ f ( x ) ∂ x = ∂ f ( x ) ∂ x \begin{aligned} & f(x) \frac{\partial\log{f(x)}}{\partial{x}} \\ = & {f(x)} \frac{1}{f(x)} \frac{\partial{f(x)}}{\partial{x}} \\ = & \frac{\partial{f(x)}}{\partial{x}} \end{aligned} ==f(x)xlogf(x)f(x)f(x)1xf(x)xf(x)

所以:

E [ s ∣ θ ] = ∫ X f ( x ∣ θ ) ⋅ s ⋅ d x = ∫ X f ( x ∣ θ ) ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X f ( x ∣ θ ) ∂ log ⁡ f ( x ∣ θ ) ∂ θ d x = ∫ X ∂ f ( x ∣ θ ) ∂ θ d x = ∂ ∂ x ∫ X f ( x ∣ θ ) d x = ∂ ∂ x 1 = 0 ■ \begin{aligned} \mathbb{E}[s|\theta] & = \int_{\mathcal{X}}f(x|\theta)\cdot{}s\cdot{}\mathrm{d}x \\ & = \int_{\mathcal{X}}f(x|\theta) \frac{\partial\log{\mathcal{L}(\theta|x)}}{\partial\theta} \mathrm{d}x \\ & = \int_{\mathcal{X}}f(x|\theta) \frac{\partial\log{f(x|\theta)}}{\partial\theta} \mathrm{d}x \\ & = \int_{\mathcal{X}} \frac{\partial f(x|\theta)}{\partial{\theta}} \mathrm{d}x \\ & = \frac{\partial}{\partial{x}} \int_{\mathcal{X}}f(x|\theta)\mathrm{d}x \\ & = \frac{\partial}{\partial{x}} 1 \\ & = 0\qquad\blacksquare \\ \end{aligned} E[sθ]=Xf(xθ)sdx=Xf(xθ)θlogL(θx)dx=Xf(xθ)θlogf(xθ)dx=Xθf(xθ)dx=xXf(xθ)dx=x1=0

因此得证: E [ s ∣ θ ] = 0 \mathbb{E}[s|\theta]= 0 E[sθ]=0

2. Fisher信息矩阵

Fisher信息(Fisher information),或简称为信息(information)是一种衡量信息量的指标

假设我们想要建模一个随机变量 x x x 的分布,用于建模的参数是 θ \theta θ,那么Fisher信息测量了 x x x 携带的对于 θ \theta θ 的信息量

所以,当我们固定 θ \theta θ 值,以 x x x 为自变量,Fisher 信息应当指出这一 x x x 值可贡献给 θ \theta θ 多少信息量

比如说,某一 θ \theta θ 点附近的函数平面非常陡峭(有一极值峰值),那么我们不需要采样多少 x x x 即可做出比较好的估计,也就是采样点 x x x 的Fisher 信息量较高。反之,若某一 θ \theta θ 附近的函数平面连续且平缓,那么我们需要采样很多点才能做出比较好的估计,也就是 Fisher 信息量较低。

从这一直观定义出发,我们可以联想到随机变量的方差,因此对于一个(假设的)真实参数 θ \theta θ s s s 函数的 Fisher 信息定义为 s s s 函数的方差

I ( θ ) = E [ ( ∂ ∂ θ log ⁡ f ( x ∣ θ ) ) 2 ∣ θ ] = ∫ ( ∂ ∂ θ log ⁡ f ( x ∣ θ ) ) 2 f ( x ; θ ) d x \begin{aligned} \mathcal{I} (\theta) & =\mathbb{E}\left[\left.\left({\frac {\partial }{\partial \theta }}\log f(x|\theta )\right)^{2}\right|\theta \right] \\ & = \int \left({\frac {\partial }{\partial \theta }}\log f(x|\theta )\right)^{2}f(x;\theta )\mathrm{d}x \end{aligned} I(θ)=E[(θlogf(xθ))2θ]=(θlogf(xθ))2f(x;θ)dx

此外,如果 log ⁡ f ( x ∣ θ ) \log f(x|\theta) logf(xθ) 对于 θ \theta θ 二次可微,那么 Fisher 信息还可以写作

I ( θ ) = − E [ ∂ 2 ∂ 2 θ log ⁡ f ( x ∣ θ ) ∣ θ ] \mathcal{I}(\theta) = -\mathbb{E}\left[\left.{\frac {\partial^2}{\partial^2 \theta }}\log f(x|\theta )\right|\theta \right] I(θ)=E[2θ2logf(xθ)θ]

证明如下:

∵ 0 = E [ s ∣ θ ] ∴ 0 = ∂ ∂ θ E [ s ∣ θ ] = ∂ ∂ θ ∫ X f ( x ∣ θ ) ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X ∂ ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ f ( x ∣ θ )   d x ▹  use chain rule = ∫ X { ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ f ( x ∣ θ ) + ∂ f ( x ∣ θ ) ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ } d x = ∫ X ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ f ( x ∣ θ ) d x ⏟ A + ∫ X ∂ L ( θ ∣ x ) ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x ⏟ B A = E [ ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ ∣ θ ] B = ∫ X ∂ L ( θ ∣ x ) ∂ θ ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X ∂ log ⁡ L ( θ ∣ x ) ∂ θ L ( θ ∣ x ) ∂ log ⁡ L ( θ ∣ x ) ∂ θ d x = ∫ X ( ∂ log ⁡ L ( θ ∣ x ) ∂ θ ) 2 f ( x ∣ θ ) d x = E [ ( ∂ log ⁡ L ( θ ∣ x ) ∂ θ ) 2 ∣ θ ] ∵ A + B = 0 ∴ E [ ∂ 2 log ⁡ L ( θ ∣ x ) ∂ 2 θ ∣ θ ] + E [ ( ∂ log ⁡ L ( θ ∣ x ) ∂ θ ) 2 ∣ θ ] = 0 \begin{aligned} \because 0 & = \mathbb{E}[s|\theta] \\ \\ \therefore 0 & = \frac{\partial}{\partial\theta} \mathbb{E}[s|\theta] \\ & = \frac{\partial}{\partial\theta}\int_{\mathcal{X}} f(x|\theta) \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x \\ & = \int_{\mathcal{X}} \frac{\partial}{\partial\theta} \boxed{\frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} f(x|\theta)}\ \mathrm{d}x \quad{\triangleright\ \textrm{use chain rule}}\\ & = \int_{\mathcal{X}} \left\{ \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}f(x|\theta) + \frac{\partial f(x|\theta)}{\partial\theta} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}}\right\} \mathrm{d}x \\ & = \underbrace{\int_{\mathcal{X}} \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}f(x|\theta) \mathrm{d}x }_\mathbf{A} + \underbrace{\int_{\mathcal{X}}\frac{\partial \mathcal{L}(\theta|x)}{\partial\theta} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x}_{\mathbf{B}} \\ \\ \mathbf{A} &= \mathbb{E}\left[\left. \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}\right| \theta \right] \\ \mathbf{B} &= \int_{\mathcal{X}} \red{\frac{\partial \mathcal{L}(\theta|x)}{\partial\theta}} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x \\ &= \int_{\mathcal{X}}\red{\frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}}\mathcal{L}(\theta|x)} \frac{\partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \mathrm{d}x\\ &= \int_{\mathcal{X}} \left(\frac{ \partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \right)^2 f(x|\theta)\mathrm{d}x \\ &= \mathbb{E}\left[\left. \left(\frac{ \partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \right)^2 \right| \theta \right] \\ \\ & \because \mathbf{A}+\mathbf{B} = 0 \\ & \therefore \mathbb{E}\left[\left. \frac{\partial^2\log\mathcal{L}(\theta|x)}{\partial^2\theta}\right| \theta \right] + \mathbb{E}\left[\left. \left(\frac{ \partial\log\mathcal{L}(\theta|x)}{\partial{\theta}} \right)^2 \right| \theta \right] = 0 \end{aligned} 00AB=E[sθ]=θE[sθ]=θXf(xθ)θlogL(θx)dx=XθθlogL(θx)f(xθ) dx use chain rule=X{2θ2logL(θx)f(xθ)+θf(xθ)θlogL(θx)}dx=A X2θ2logL(θx)f(xθ)dx+B XθL(θx)θlogL(θx)dx=E[2θ2logL(θx)θ]=XθL(θx)θlogL(θx)dx=XθlogL(θx)L(θx)θlogL(θx)dx=X(θlogL(θx))2f(xθ)dx=E[(θlogL(θx))2θ]A+B=0E[2θ2logL(θx)θ]+E[(θlogL(θx))2θ]=0

二、EWC

1. 数学推导

假设数据集被划分为两个任务 Σ = { A , B } \Sigma = \{\mathcal{A, B}\} Σ={A,B},网络参数为 θ \theta θ

学习任务为最大化后验概率

arg max ⁡ θ P ( θ ∣ Σ ) = arg max ⁡ θ log ⁡ P ( θ ∣ Σ ) = arg min ⁡ θ l ( θ ) \begin{aligned} & \argmax_{\theta}P(\theta | \Sigma) \\ = & \argmax_{\theta} \log P(\theta | \Sigma) \\ = & \argmin_{\theta} l(\theta) \end{aligned} ==θargmaxP(θΣ)θargmaxlogP(θΣ)θargminl(θ)

其中 l ( θ ) l(\theta) l(θ)定义为训练 loss

考虑任务训练顺序 A ⇒ B \mathcal{A} \Rightarrow \mathcal{B} AB

log ⁡ P ( θ ∣ Σ ) = log ⁡ P ( θ ∣ A , B ) = log ⁡ P ( θ , A , B ) − log ⁡ P ( A , B ) = log ⁡ P ( B ∣ θ , A ) + log ⁡ P ( θ , A ) − log ⁡ P ( B ∣ A ) − log ⁡ P ( A ) = log ⁡ P ( B ∣ θ ) + log ⁡ P ( θ ∣ A ) + log ⁡ P ( A ) − log ⁡ P ( B ) − log ⁡ P ( A ) ▹ A, B i.i.d  = log ⁡ P ( B ∣ θ ) ⏟ l o s s   o n   B + log ⁡ P ( θ ∣ A ) ⏟ unknown − log ⁡ P ( B ) ⏟ constant \begin{aligned} & \log P(\theta | \Sigma) \\ &= \log P(\theta | A, B) \\ &= \log P(\theta, A, B) - \log P(A, B) \\ &= \log P(B|\theta, A) + \log P(\theta, A) - \log P(B|A) - \log P(A) \\ &= \log P(B|\theta) + \log P(\theta | A) + \log P(A) - \log P(B) - \log P(A) \qquad \triangleright \textrm{A, B i.i.d } \\ &= \underbrace{\log P(B|\theta)}_{\mathrm{loss\ on\ \mathcal{B}}} + \underbrace{\log P(\theta|A)}_{\textrm{unknown}} - \underbrace{\log P(B)}_{\textrm{constant}} \\ \end{aligned} logP(θΣ)=logP(θA,B)=logP(θ,A,B)logP(A,B)=logP(Bθ,A)+logP(θ,A)logP(BA)logP(A)=logP(Bθ)+logP(θA)+logP(A)logP(B)logP(A)A, B i.i.d =loss on B logP(Bθ)+unknown logP(θA)constant logP(B)

其中后验概率 log ⁡ P ( θ ∣ A ) \log P(\theta | A) logP(θA) 不易得到,因此使用拉普拉斯近似进行分析

在训练任务 B \mathcal{B} B 之前,网络已经在任务 A \mathcal{A} A 上收敛,设网络此时的参数为 θ A ∗ \theta_{A}^* θA,为在任务 A \mathcal{A} A 上拟合得到的参数,设函数 f ( θ ) = log ⁡ P ( θ ∣ A ) f(\theta) = \log P(\theta | A) f(θ)=logP(θA)

f ( θ ) f(\theta) f(θ) θ = θ A ∗ \theta = \theta_{A}^* θ=θA 做泰勒展开:

f ( θ ) = f ( θ A ∗ ) + ∂ f ( θ ) ∂ θ ∣ θ A ∗ ⏟ = 0 ( θ − θ A ∗ ) + 1 2 ( θ − θ A ∗ ) T ∂ 2 f ( θ ) ∂ 2 θ ∣ θ A ∗ ( θ − θ A ∗ ) + ⋯ ≈ f ( θ A ∗ ) + 1 2 ( θ − θ A ∗ ) T ∂ 2 f ( θ ) ∂ 2 θ ∣ θ A ∗ ( θ − θ A ∗ ) \begin{aligned} f(\theta) &= f(\theta_{A}^*) + \underbrace{\left.\frac{\partial f(\theta)}{\partial\theta}\right|_{\theta_A^*}}_{=0}(\theta - \theta_{A}^*) + \frac{1}{2}(\theta - \theta_{A}^*)^T \left.\frac{\partial^2f(\theta)}{\partial^2\theta}\right|_{\theta_A^*} (\theta - \theta_{A}^*) + \cdots \\ &\approx f(\theta_{A}^*) + \frac{1}{2}(\theta - \theta_{A}^*)^T \left.\frac{\partial^2f(\theta)}{\partial^2\theta}\right|_{\theta_A^*} (\theta - \theta_{A}^*) \\ \end{aligned} f(θ)=f(θA)+=0 θf(θ)θA(θθA)+21(θθA)T2θ2f(θ)θA(θθA)+f(θA)+21(θθA)T2θ2f(θ)θA(θθA)

f ( θ ) = log ⁡ P ( θ ∣ A ) f(\theta) = \log P(\theta | A) f(θ)=logP(θA) 代入:

log ⁡ P ( θ ∣ A ) = log ⁡ P ( θ A ∗ ∣ A ) + 1 2 ( θ − θ A ∗ ) T ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ( θ − θ A ∗ ) = log ⁡ P ( θ A ∗ ∣ A ) + 1 2 ( θ − θ A ∗ ) T { − [ − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ] − 1 } − 1 ( θ − θ A ∗ ) P ( θ ∣ A ) = exp ⁡ [ Δ + 1 2 ( θ − θ A ∗ ) T { − [ − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ] − 1 } − 1 ( θ − θ A ∗ ) ] = ϵ exp ⁡ [ − 1 2 ( θ − θ A ∗ ) T { [ − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ] − 1 } − 1 ⏟ Σ − 1 ( θ − θ A ∗ ) ] where: Δ = log ⁡ P ( θ A ∗ ∣ A ) ϵ = exp ⁡ Δ \begin{aligned} \log P(\theta|A) &= \log P(\theta_A^* | A) + \frac{1}{2}(\theta - \theta_{A}^*)^T \left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*} (\theta - \theta_{A}^*) \\ &= \log P(\theta_A^* | A) + \frac{1}{2}(\theta - \theta_{A}^*)^T\left\{-\left[-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1} (\theta - \theta_{A}^*) \\ P(\theta|A) &= \exp{\left[ \Delta + \frac{1}{2}(\theta - \theta_{A}^*)^T\left\{-\left[-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1} (\theta - \theta_{A}^*) \right]} \\ &= \epsilon\exp{\left[-\frac{1}{2}(\theta - \theta_{A}^*)^T \underbrace{\left\{\left[-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right]^{-1}\right\}^{-1}}_{\Sigma^{-1}} (\theta - \theta_{A}^*) \right]} \\ \textbf{where:} & \\ \Delta &= \log P(\theta_A^* | A) \\ \epsilon &= \exp{\Delta} \end{aligned} logP(θA)P(θA)where:Δϵ=logP(θAA)+21(θθA)T2θ2logP(θA)θA(θθA)=logP(θAA)+21(θθA)T[2θ2logP(θA)θA]11(θθA)=expΔ+21(θθA)T[2θ2logP(θA)θA]11(θθA)=ϵexp21(θθA)TΣ1 [2θ2logP(θA)θA]11(θθA)=logP(θAA)=expΔ

观察形式可得:

P ( θ ∣ A ) ∼ N ( θ A ∗ , ( − ∂ 2 log ⁡ P ( θ ∣ A ) ∂ 2 θ ∣ θ A ∗ ) − 1 ) P(\theta | \mathcal{A}) \sim \mathcal{N}\left(\theta_{\mathcal{A}}^*,\left(-\left.\frac{\partial^2\log P(\theta|A)}{\partial^2\theta}\right|_{\theta_A^*}\right)^{-1}\right) P(θA)NθA,(2θ2logP(θA)θA)1

其中协方差矩阵项正是第一部分讨论的Fisher信息矩阵,记做 I A \mathbf{I}_{\mathcal{A}} IA,则有

P ( θ ∣ A ) ∼ N ( θ A ∗ , [ I A ] − 1 ) P(\theta | \mathcal{A}) \sim \mathcal{N}\left(\theta_{\mathcal{A}}^*,\left[\mathbf{I}_{\mathcal{A}}\right]^{-1} \right) P(θA)N(θA,[IA]1)

另外,EWC是以一个参数的视角出发的,因此Fisher信息矩阵只需要对角线元素,其余计算出来的结果可以置0,所以有:

P ( θ ∣ A ) = 1 ( 2 π ) k ∣ Σ ∣ exp ⁡ { − 1 2 ( θ − θ A ∗ ) T Σ − 1 ( θ − θ A ∗ ) } log ⁡ P ( θ i ∣ A ) = − 1 2 ( θ i − [ θ i ] A ∗ ) 2 ∗ [ Σ − 1 ] i i = − [ I A ] i i ( θ i − [ θ i ] A ∗ ) 2 2 \begin{aligned} P(\theta | A) &= \frac{1}{\sqrt{(2\pi)^k |\Sigma|}} \exp\{-\frac{1}{2}(\theta - \theta_A^*)^T\Sigma^{-1}(\theta - \theta_{A}^*)\}\\ \\ \log P(\theta_{i} | A) &= -\frac{1}{2}(\theta_i - [\theta_i]_A^*)^2 * [\Sigma^{-1}]_{ii} \\ &= -\left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \end{aligned} P(θA)logP(θiA)=(2π)kΣ 1exp{21(θθA)TΣ1(θθA)}=21(θi[θi]A)2[Σ1]ii=[IA]ii2(θi[θi]A)2

所以,所有参数的EWC Loss可定义为:

l EWC = − ∑ i = 1 #Params [ I A ] i i ( θ i − [ θ i ] A ∗ ) 2 2 \begin{aligned} l_{\textbf{EWC}} = -\sum_{i=1}^{\textrm{\#Params}} \left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \end{aligned} lEWC=i=1#Params[IA]ii2(θi[θi]A)2

将上述内容代入总优化目标:

l ( θ ) = log ⁡ P ( θ ∣ Σ ) = log ⁡ P ( B ∣ θ ) + log ⁡ P ( θ ∣ A ) − log ⁡ P ( B ) ⇒ l CE ( θ ∣ B ) + l EWC ( θ ∣ A ) \begin{aligned} l(\theta) &= \log P(\theta | \Sigma) \\ &= \log P(B|\theta) + \log P(\theta|A) - \log P(B) \\ &\Rightarrow l_{\textbf{CE}} (\theta | \mathcal{B}) + l_{\textbf{EWC}}(\theta | \mathcal{A}) \end{aligned} l(θ)=logP(θΣ)=logP(Bθ)+logP(θA)logP(B)lCE(θB)+lEWC(θA)

定义超参数 λ \lambda λ 进行稳定性-可塑性权衡

l ( θ ) = l CE ( θ ∣ B ) + λ ⋅ l EWC ( θ ∣ A ) \begin{aligned} l(\theta) &= l_{\textbf{CE}} (\theta | \mathcal{B}) + \lambda \cdot l_{\textbf{EWC}}(\theta | \mathcal{A}) \end{aligned} l(θ)=lCE(θB)+λlEWC(θA)

因此优化目标为:

arg min ⁡ θ l ( θ ) = arg min ⁡ θ { l CE ( θ ∣ B ) + λ ⋅ l EWC ( θ ∣ A ) } = arg min ⁡ θ { l CE ( θ ∣ B ) − λ 2 ∑ i = 1 #Params [ I A ] i i ( θ i − [ θ i ] A ∗ ) 2 2 } \begin{aligned} \argmin_{\theta}l(\theta) &= \argmin_{\theta}\left\{l_{\textbf{CE}} (\theta | \mathcal{B}) + \lambda \cdot l_{\textbf{EWC}}(\theta | \mathcal{A})\right\} \\ &= \argmin_{\theta} \left\{ l_{\textbf{CE}} (\theta | \mathcal{B}) - \frac{\lambda}{2}\sum_{i=1}^{\textrm{\#Params}} \left[\mathbf{I}_{\mathcal{A}}\right]_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \right\} \end{aligned} θargminl(θ)=θargmin{lCE(θB)+λlEWC(θA)}=θargmin{lCE(θB)2λi=1#Params[IA]ii2(θi[θi]A)2}

2. 如何计算 Fisher 信息矩阵

将训练过程进行划分:

  1. 使用数据集 A \mathcal{A} A l CE ( θ ∣ A ) l_{\textbf{CE}}(\theta|\mathcal{A}) lCE(θA) 训练模型
  2. 保存此时的参数,即 θ A ∗ \theta_{\mathcal{A}}^* θA ,并计算 Fisher 信息矩阵 I A \mathbf{I}_{\mathcal{A}} IA
  3. 使用数据集 B \mathcal{B} B l CE ( θ ∣ B ) l_{\textbf{CE}}(\theta|\mathcal{B}) lCE(θB) l EWC ( θ ∣ A ) l_{\textbf{EWC}}(\theta | \mathcal{A}) lEWC(θA) 训练模型
  4. 多任务 { A , B , C , … } \{\mathcal{A, B, C, \dots}\} {A,B,C,}同理

最后一个问题,如何使用 A \mathcal{A} A 计算 I A \mathbf{I}_{\mathcal{A}} IA

考虑定义

I ( θ ) = E [ ( ∂ ∂ θ log ⁡ f ( x ∣ θ ) ) 2 ∣ θ ] \mathcal{I}(\theta) = \mathbb{E}\left[\left.\left({\frac {\partial }{\partial \theta }}\log f(x|\theta )\right)^{2}\right|\theta \right] \\ I(θ)=E[(θlogf(xθ))2θ]

可以通过计算梯度的平方来获得每一个参数的 Fisher 信息矩阵项:

I ( θ ) = 1 N ∑ ( x , y ) i ∈ A ( ∂ l LL ( θ ∣ ( x , y ) i ) ∂ θ ⏟ Gradient ) 2 \mathcal{I}(\theta) = \frac{1}{N} \sum_{(x,y)_{i} \in \mathcal{A}}\left(\underbrace{\frac{\partial{l_{\textbf{LL}}(\theta|(x, y)_{i})}}{\partial\theta}}_{\textrm{Gradient}}\right)^2 I(θ)=N1(x,y)iAGradient θlLL(θ(x,y)i)2

具体来说,可以向模型逐个喂入样本,并计算损失函数,使用神经网络框架自动计算梯度。对于每个参数,累加所有的梯度,最后除以样本数量,即可得到对应参数的 Fisher 信息矩阵项

需要注意的是,当使用 nn.CrossEntrypyLossnn.NLLLoss时,由于其中对于 Log-Likelihood \textrm{Log-Likelihood} Log-Likelihood使用了相反数处理,使用该类损失函数得到的矩阵是真实 Fisher 信息矩阵的相反数,即 I ^ ( θ ) = − I ( θ ) \hat{\mathcal{I}}(\theta) = -\mathcal{I}(\theta) I^(θ)=I(θ),在计算 loss 时要记得将减号改成加号,即

arg min ⁡ θ l ( θ ) = arg min ⁡ θ { l CE ( θ ∣ B ) + λ 2 ∑ i = 1 #Params [ I A ] ^ i i ( θ i − [ θ i ] A ∗ ) 2 2 } \argmin_{\theta}l(\theta) = \argmin_{\theta} \left\{ l_{\textbf{CE}} (\theta | \mathcal{B}) + \frac{\lambda}{2}\sum_{i=1}^{\textrm{\#Params}} \red{\hat{\left[\mathbf{I}_{\mathcal{A}}\right]}}_{ii}\frac{(\theta_i - [\theta_i]_A^*)^2}{2} \right\} θargminl(θ)=θargmin{lCE(θB)+2λi=1#Params[IA]^ii2(θi[θi]A)2}

全文完

版权声明:自由转载-非商用-禁止演绎-保持署名 4.0 国际许可协议

  • 32
    点赞
  • 78
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
### 回答1: PyTorch是一个开源的深度学习框架,它为持续学习提供了很好的支持。持续学习是指通过不断地学习新的数据、调整模型和继续训练,从而实现模型的优化和更新。下面是使用PyTorch实现持续学习的一些关键步骤: 1. 数据处理:将新的数据加载到PyTorch中,并进行预处理操作,例如数据标准化、数据增强等。可以使用PyTorch中的数据加载器(DataLoader)和数据预处理工具(transform)加快处理过程。 2. 模型加载:加载已经训练好的模型参数,可以使用PyTorch的torch.load()函数加载先前训练模型的参数。 3. 模型调整:根据新的数据特点,对模型进行微调或调整。可以使用PyTorch提供的灵活的模型定义和修改方式,例如修改模型的层结构、修改激活函数等。 4. 优化器选择:选择合适的优化器,例如Adam、SGD等,以在持续学习过程中调整模型的权重。 5. 训练过程:使用新的数据对模型进行训练,并反复迭代调整模型。可以使用PyTorch提供的自动微分功能,加快梯度计算和模型更新过程。 6. 模型保存:在每次训练迭代结束后,保存模型的最新参数。可以使用PyTorch的torch.save()函数保存模型参数。 7. 持续学习:重复上述步骤,对新的数据进行处理、模型调整和训练过程,以实现模型的持续学习。 通过上述步骤,使用PyTorch可以实现持续学习的过程。凭借其灵活性和强大的计算能力,PyTorch能够满足各种深度学习模型对于持续学习的需求,并为模型的优化提供支持。同时,PyTorch还提供了丰富的工具和函数,帮助开发者更高效地实现持续学习。 ### 回答2: PyTorch是一个开源的机器学习框架,它提供了丰富的工具和功能来支持持续学习持续学习指的是通过新数据的输入,持续改进和更新现有的模型,以适应不断变化的环境和任务。 PyTorch提供了一个灵活和可扩展的架构,使得持续学习变得更加容易。以下是在PyTorch中实现持续学习的一些关键步骤: 1. 数据管理:持续学习需要处理不断变化的数据。PyTorch中的DataLoader和Dataset类可以帮助加载和管理数据集。您可以创建一个数据加载器来批量加载新的数据集,并将其与之前的数据集合并。 2. 模型更新:当有新的数据到达时,您可以使用PyTorch的优化器来更新模型的参数,以适应新的数据。您可以使用反向传播算法计算损失,并调用优化器的`step`函数来更新模型的参数。 3. 继续训练:持续学习意味着在之前训练的基础上继续学习。您可以加载之前训练保存的模型,并在新的数据上进行训练。在PyTorch中,您可以使用`torch.load`函数加载之前训练的模型,并通过调用`train`函数来继续训练。 4. 模型评估:持续学习需要在新的数据上进行模型评估,以评估其性能和适应能力。您可以使用PyTorch中的评估函数和指标来评估模型的准确性和效果。 5. 灵活性:PyTorch的灵活性使得您可以自定义和调整模型结构,以适应不同的任务和数据。您可以根据新的数据特点调整模型的层次、结构和参数。 总之,PyTorch为持续学习提供了丰富的功能和易用的工具。通过管理数据、更新模型、继续训练和模型评估,您可以在PyTorch中有效地实现持续学习。 ### 回答3: 在PyTorch中实现持续学习的关键是使用动态图的特性和灵活的模型更新方法。 首先,PyTorch的动态图机制允许我们在运行时构建和修改模型图,这使得持续学习更加容易。我们可以将新的数据集添加到已经训练的模型上,并通过反向传播来更新模型的权重。这样,我们可以通过在已有模型上继续训练来逐步适应新的数据,而无需重新训练整个模型。 其次,持续学习的另一个重要问题是防止旧知识的遗忘。为了解决这个问题,我们可以使用增量学习方法,如Elastic Weight Consolidation(EWC)或Online Deep Learning(ODL)。这些方法通过使用正则化项或定义损失函数来限制新训练数据对旧知识的影响,从而保护旧有的模型参数。 此外,我们还可以使用PyTorch提供的模型保存和加载功能来实现持续学习。我们可以定期保存模型的参数和优化器状态,以便在需要时恢复模型,并继续训练过程。通过这种方式,我们可以持续积累更多的数据和知识,而无需从头开始每次都重新训练模型。 总的来说,PyTorch提供了灵活的动态图和丰富的工具,使得实现持续学习变得简单。我们可以通过动态修改模型图、使用增量学习方法来应对新数据和旧知识的挑战,并使用模型保存和加载功能来持续积累数据和知识。这些方法的组合可以帮助我们在PyTorch中实现高效的持续学习

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值