Understanding Self-Supervised Learning Dynamics without Contrastive Pairs (Paper Notes)
Two layer linear model
Points
- Weight decay can balance the predictor and online network weights.
- Stop-gradient can prevent collapse.
- Stop-gradient with no predictor still leads to collapse.
Model
$W_p \in \mathbb{R}^{n_2 \times n_2}$, $W \in \mathbb{R}^{n_2 \times n_1}$, $W_a \in \mathbb{R}^{n_2 \times n_1}$, $x \in \mathbb{R}^{n_1}$. $f_1 = W x_1 \in \mathbb{R}^{n_2}$, $f_{2a} = W_a x_2 \in \mathbb{R}^{n_2}$.
$x_1, x_2$ are two augmented views.
Loss Function
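Combining the quantities defined above, the per-pair objective is the squared distance between the predicted online output $W_p f_1$ and the stop-gradient target $f_{2a}$ (this is my reconstruction of the formula from the paper; weight decay $\eta$ enters the dynamics separately):

$$
L = \frac{1}{2}\,\mathbb{E}\left[\big\| W_p f_1 - \operatorname{StopGrad}(f_{2a}) \big\|_2^2\right]
  = \frac{1}{2}\,\mathbb{E}\left[\big\| W_p W x_1 - \operatorname{StopGrad}(W_a x_2) \big\|_2^2\right]
$$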
Gradient
$\eta$: weight decay. $X = \mathbb{E}_x[\bar x \bar x^T]$, $\bar x(x) := \mathbb{E}_{x' \sim p_{aug}(\cdot|x)}[x']$, $X' = \mathbb{E}_x[\mathbb{V}_{x'|x}[x']]$, where $x'$ is an augmented view of $x$.
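For concreteness, a minimal Monte-Carlo sketch of these two matrices (my own illustration, not from the paper; the `augment` callback and `n_aug` are assumed):

```python
import numpy as np

def estimate_X_Xprime(xs, augment, n_aug=100):
    """Monte-Carlo estimates of X = E_x[x_bar x_bar^T] and X' = E_x[Var_{x'|x}[x']].

    xs:      data samples, shape (N, n1)
    augment: function returning one random augmented view x' ~ p_aug(.|x)
    """
    n1 = xs.shape[1]
    X = np.zeros((n1, n1))
    Xp = np.zeros((n1, n1))
    for x in xs:
        views = np.stack([augment(x) for _ in range(n_aug)])  # samples of x'
        x_bar = views.mean(axis=0)                            # x_bar(x)
        X += np.outer(x_bar, x_bar)
        centered = views - x_bar
        Xp += centered.T @ centered / n_aug                   # Var_{x'|x}[x']
    return X / len(xs), Xp / len(xs)
```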
Proof
- Weight decay can balance the predictor and online network weights (shown by removing $W_a$); see the worked consequence after this list.
  $$\alpha_p^{-1}\left[e^{2\eta_p t}\, W_p^T W_p - W_p^T(0) W_p(0)\right] = e^{2\eta t}\, W W^T - W(0) W^T(0)$$
- Stop-gradient can prevent collapse.
  $$H(t) := X' \otimes (W_p^T W_p + I) + X \otimes \tilde W_p^T \tilde W_p + \eta I_{n_1 n_2}, \quad \tilde W_p := W_p - I_{n_2}$$
  $$H(t) = X' \otimes (W_p^T W_p + I) + X \otimes (W_p^T W_p - 2 W_p + I) + \eta I_{n_1 n_2}$$
- Stop-gradient with no predictor ($W_p = I$) leads to collapse.
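As a worked consequence of the first identity above (assuming for simplicity equal weight decay $\eta_p = \eta > 0$, which is my simplification rather than the paper's statement), dividing both sides by $e^{2\eta t}$ gives

$$
\alpha_p^{-1} W_p^T(t) W_p(t) - W(t) W^T(t)
= e^{-2\eta t}\left[\alpha_p^{-1} W_p^T(0) W_p(0) - W(0) W^T(0)\right] \longrightarrow 0,
$$

so any initial imbalance between $W_p^T W_p$ and $\alpha_p\, W W^T$ decays at rate $2\eta$: this is the sense in which weight decay balances the two sets of weights.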
Multiple factors analysis
Three assumptions are made to decouple the matrix gradient dynamics into scalar dynamics.
Assumptions
Assumption 1 (Proportional EMA): the target (EMA) network stays proportional to the online network, $W_a(t) = \tau(t) W(t)$ for a scalar $\tau(t)$.
Validated by experiments in the paper.
Assumption 2 (Isotropic data and augmentation).
The data distribution $p(x)$ has zero mean and identity covariance.
The augmentation distribution $p_{aug}(\cdot|x)$ has mean $x$ and covariance $\sigma^2 I$.
Hence $X = I$, $X' = \sigma^2 I$.
(This assumption also appears in previous work.)
Assumption 3 (Symmetric predictor $W_p$)
$W_p = W_p^T$
- Motivation
  - The fixed point ($\dot W_p = 0$) is symmetric in some cases (for particular $\eta$ and $W W^T$).
  - Under Assumptions 1 and 2, the asymmetric part $W_p - W_p^T$ vanishes (for particular $\eta$, $\tau$).
- Experimental phenomena
  - BYOL: symmetric $W_p$ is slightly better than a regular one.
  - SimSiam: symmetric $W_p$ fails (why?).
  - $\bar\eta = 0.0004$, $\alpha_p = 1$.
Conclusion
Under the above assumptions, the eigenspaces (eigenvectors) of $F$ and $W_p$ gradually align (validated experimentally, Fig. 9 in the paper).
(When $\eta$ is small or zero and $\tau$ is large, the alignment vanishes.)
$F := W X W^T$. Note that $F$ is the correlation matrix of $W x_1$. By Assumption 2, $\mathbb{E}[x] = 0$, so $F$ is also the covariance matrix. (?)
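One simple way to check this alignment claim numerically (my own sanity-check sketch, not from the paper; `W_p` and `F` would come from a training run): diagonalize $F$ and measure how much of $W_p$'s energy sits on the diagonal in that basis.

```python
import numpy as np

def eigenspace_alignment(W_p, F):
    """Return a score in [0, 1]; 1 means W_p is exactly diagonal in F's eigenbasis."""
    _, U = np.linalg.eigh((F + F.T) / 2)   # eigenvectors of (symmetrized) F
    M = U.T @ W_p @ U                      # W_p expressed in F's eigenbasis
    diag_energy = np.sum(np.diag(M) ** 2)
    return diag_energy / np.sum(M ** 2)
```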
$$W_p = U \Lambda_{W_p} U^T, \quad F = U \Lambda_F U^T, \quad \dot W_p = U G_1 U^T, \quad \dot F = U G_2 U^T \quad (\dot U = 0)$$
$$\Lambda_{W_p} = \mathrm{diag}[p_1, p_2, \ldots, p_d], \quad \Lambda_F = \mathrm{diag}[s_1, s_2, \ldots, s_d]$$
Substituting these into the gradient equations gives the scalar dynamics for $p_j$ and $s_j$ (Eqn. 11 and Eqn. 12 in the paper).
Analysis of $\alpha_p$, $\eta$, $\beta$
(relative predictor learning rate, weight decay, and EMA parameter)
- From Eqn. 11 and Eqn. 12 (removing $\tau$), we have
  $$c_j = s_j(0) - \alpha_p^{-1} p_j^2(0),$$
  which closely resembles the earlier weight-balance identity.
  - When $\eta = 0$ and $\alpha_p$ is too small, $c_j < 0$ and $s_j \to 0$: the system might collapse.
  - When $\eta = 0$ and $\alpha_p$ is too large, $s_j(+\infty) = s_j(0)$.
  - When $\eta > 0$, $s_j$ and $\alpha_p^{-1} p_j^2$ balance: $s_j = \alpha_p^{-1} p_j^2$.
- When $\eta > 0$, substituting $s_j = \alpha_p^{-1} p_j^2$ into Eqn. 11 gives
  $$\dot p_j = p_j \Delta_j, \qquad \Delta_j := p_j\left[\tau - (1+\sigma^2) p_j\right] - \eta$$
Setting $\dot p_j = 0$ gives the fixed points: $p_{j0} = 0$ and the roots of $\Delta_j = 0$.
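Solving the quadratic $\Delta_j = 0$ explicitly (simple algebra on the definition above):

$$
p_{j\pm} = \frac{\tau \pm \sqrt{\tau^2 - 4\eta(1+\sigma^2)}}{2(1+\sigma^2)},
$$

with real solutions only when $\eta \le \frac{\tau^2}{4(1+\sigma^2)}$, which matches the collapse condition below.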
The plot of $\dot p_j$ as a function of $p_j$ is shown below: $p_{j-}$ is unstable, $p_{j0} = 0$ is the collapsed solution, and $p_{j+}$ is stable.
(Figure: curve of $\dot p_j$ versus $p_j$ showing the fixed points $p_{j0}$, $p_{j-}$, $p_{j+}$; original screenshot not available.)
- A larger $\eta$ or a smaller $\tau$ moves $p_{j-}$ to the right, so the basin of collapse expands.
- When $\eta > \frac{\tau^2}{4(1+\sigma^2)}$, collapse is unavoidable.
- To satisfy eigenspace alignment, a condition involving
  $$\Delta_j := p_j\left[\tau - (1+\sigma^2) p_j\right] - \eta$$
  has to hold, so a larger $\eta$ and a larger $\alpha_p$ can loosen the bound.
- Curriculum learning: initially $p_j$ and $s_j$ are small, and since $W$ changes rapidly, $\tau$ is also small. When $p_j$ approaches its stable fixed point $p_j^+$, $p_j$ and $s_j$ stop growing, $\tau$ becomes larger, and this in turn sets a higher $p_j^+$.
- $\alpha_p > 1$, $\eta_p > \eta_s$ can make a symmetric $W_p$ work without EMA (SimSiam)? (Perhaps because the eigenspace-alignment condition is easier to satisfy?)
Motivation
- satisfy eigenspace alignment directly
- initialize W p W_p Wp outside the basin of collapse.
Implementation
Estimate $\hat F$ by a moving average of $\mathbb{E}_B[f f^T]$, where $f = W x$ and $\mathbb{E}_B[\cdot]$ is the expectation over a batch.
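A minimal sketch of this update in NumPy; the moving-average rate `rho`, the eigenvalue floor `eps`, and the exact mapping from $s_j$ to $p_j$ (here $p_j \propto \sqrt{s_j}$, consistent with the balance relation $s_j = \alpha_p^{-1} p_j^2$ above) are my assumptions, not the paper's exact recipe:

```python
import numpy as np

def update_predictor(F_hat, f_batch, rho=0.3, eps=0.1):
    """Recompute the predictor W_p from a running estimate of F = E[f f^T].

    F_hat:   current running estimate of F, shape (n2, n2)
    f_batch: online-network outputs f = W x for one batch, shape (B, n2)
    rho:     moving-average rate (assumed hyperparameter)
    eps:     floor added to the normalized eigenvalues (assumed hyperparameter)
    """
    # moving-average estimate of the correlation matrix F
    F_batch = f_batch.T @ f_batch / len(f_batch)      # E_B[f f^T]
    F_hat = (1 - rho) * F_hat + rho * F_batch

    # eigendecompose F_hat; W_p will share its eigenvectors,
    # so eigenspace alignment holds by construction
    s, U = np.linalg.eigh(F_hat)
    s = np.clip(s, 0.0, None)

    # set predictor eigenvalues from s_j (p_j ~ sqrt(s_j) matches s_j = alpha_p^{-1} p_j^2)
    s_max = s.max()
    p = np.sqrt(s / s_max) + eps if s_max > 0 else np.full_like(s, eps)
    W_p = U @ np.diag(p) @ U.T
    return F_hat, W_p
```

Because $W_p$ is built from $\hat F$'s eigenvectors, the eigenspace-alignment condition holds by construction, and the floor `eps` keeps every $p_j$ away from the collapsed fixed point $p_{j0} = 0$.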
Problem
Most of the analysis assumes a symmetric $W_p$, but in practice a regular (non-symmetric) $W_p$ also works.