Long-Tailed Classification by Keeping the Good and Removing the Bad Momentum Causal Effect
M代表torch中实现SGD中的momentum算法。momentum算法引入了动量的概念,类似于惯性,其目的是让梯度下降的更加稳定;但是它随着训练过程一直累积着,相当于LSTM的时序信息一样,将其他batch继承了过来,影响着参数的调整,从而影响特征提取的过程,从而影响X。
D表示特征X在向head的特征的方向上偏离的投影。M是影响X在向head的特征偏移的元凶,由于X在向head的特征偏移所以导致Y对于long-tail的loss很大。
计算 P ( Y ∣ X ) P(Y|X) P(Y∣X)时存在一条后门路径X<-M->D->Y,因此需要 d o ( X ) do(X) do(X)来消除后门路径的影响。
P ( Y = i ∣ d o ( X = x ) ) = ∑ m ∈ M ( Y = i ∣ X = x , M = m ) P ( M = m ) = ∑ m ∈ M P ( Y = i , X = x ∣ M = m ) P ( X = x ∣ M = m ) \begin{aligned} P(Y=i|do(X=x))&=\displaystyle \sum_{m \in M}(Y=i|X=x,M=m)P(M=m) \\ &=\displaystyle \sum_{m \in M}\frac{P(Y=i,X=x|M=m)}{P(X=x|M=m)} \end{aligned} P(Y=i∣do(X=x))=m∈M∑(Y=i∣X=x,M=m)P(M=m)=m∈M∑P(X=x∣M=m)P(Y=i,X=x∣M=m)
d o ( X ) do(X) do(X)后的因果图是: