VICREG: VARIANCE-INVARIANCE-COVARIANCE REGULARIZATION FOR SELF-SUPERVISED LEARNING

Code: https://github.com/facebookresearch/vicreg

 

Math and Pseudocode Description

Now that we understand the method conceptually, we can dig into the math. If you're someone who thinks better in code, we also include actual snippets from our PyTorch implementation.

In the equations below, let $Z$ be the $n \times d$ matrix representing a batch, where $n$ and $d$ are the batch size and embedding dimension, respectively. Let $z_{i:}$ be the $i$th vector in the batch and let $z_{:j}$ be a vector composed of the $j$th element of each vector in the batch.
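The snippets below operate on two such batches, z_a and z_b: the embeddings of two augmented views of the same images. As a rough, hypothetical sketch of the setup (the encoder/expander definitions, names, and sizes here are toy stand-ins, not the actual implementation; the paper uses a ResNet-50 encoder and a much larger three-layer expander):

import torch
from torch import nn

# Hypothetical setup, shrunk to toy sizes, just to show where the n x d
# batches z_a and z_b used in the snippets below come from.
n, d = 256, 128
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
expander = nn.Sequential(nn.Linear(64, d), nn.ReLU(), nn.Linear(d, d))

x_a = torch.randn(n, 3, 32, 32)   # stand-in for one augmented view of a batch of images
x_b = torch.randn(n, 3, 32, 32)   # stand-in for a second augmented view of the same images
z_a = expander(encoder(x_a))      # Z  : n x d
z_b = expander(encoder(x_b))      # Z' : n x d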

Variance

The variance term $v(Z)$ captures the variance of each embedding variable over a batch:

$$\text{Var}(z_{:j}) = \frac{1}{n-1}\sum_{i=1}^n (z_{ij}-\bar{z}_{j})^2, \quad \bar{z}_{j} = \frac{1}{n}\sum_{i=1}^n z_{ij}$$

$$v(Z) = \frac{1}{d}\sum_{j=1}^{d}\max\!\left(0,\ \gamma-\sqrt{\text{Var}(z_{:j})+\epsilon}\right)$$

where $\gamma$ is the target value for the standard deviation (they choose $\gamma = 1$), and $\epsilon$ is a small scalar put in place to prevent numerical instabilities (they choose $\epsilon = 0.0001$).

Notice that minimizing $v(Z)$ means forcing the batch-wise standard deviation to be above $\gamma$. As soon as this target is achieved, $v(Z)$ bottoms out at $0$. A hinge function is used here because the point is not to encourage ever-increasing variance; higher variance isn't necessarily better, it just needs to be above a certain threshold to avoid catastrophic failure, i.e., mode collapse.
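As a quick worked example with $\gamma = 1$: a dimension whose batch standard deviation is $0.2$ contributes roughly $\max(0,\ 1 - 0.2) = 0.8$ to the average, while any dimension whose standard deviation is at or above $1$ contributes $0$.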

In PyTorch code:

 

# variance loss: hinge on the per-dimension standard deviation of each batch
std_z_a = torch.sqrt(z_a.var(dim=0) + self.hparams.variance_loss_epsilon)
std_z_b = torch.sqrt(z_b.var(dim=0) + self.hparams.variance_loss_epsilon)
loss_v_a = torch.mean(F.relu(1 - std_z_a))  # gamma = 1
loss_v_b = torch.mean(F.relu(1 - std_z_b))
loss_var = loss_v_a + loss_v_b
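As a quick sanity check of the snippet above, here is a standalone version with the epsilon hard-coded (1e-4, as in the paper) rather than read from hparams; a collapsed batch is heavily penalized while a well-spread one is not:

import torch
import torch.nn.functional as F

eps = 1e-4
z_collapsed = torch.ones(256, 128)       # every embedding identical: per-dim std is 0
z_spread = torch.randn(256, 128) * 2.0   # per-dim std is about 2, above the gamma = 1 target

for z in (z_collapsed, z_spread):
    std = torch.sqrt(z.var(dim=0) + eps)
    print(F.relu(1 - std).mean().item())  # about 0.99 for the collapsed batch, about 0 for the spread one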

Invariance

The invariance term $s(Z,Z')$ captures the invariance between positive pairs of embedding vectors:

$$s(Z,Z') = \frac{1}{n}\sum_{i}\lVert z_i - z'_i \rVert_2^2$$

This is just a simple mean-squared Euclidean distance metric. Notably, the $z$ vectors are un-normalized. In the paper, the authors also run experiments using the cosine similarity metric of SimSiam (which has the effect of projecting the vectors onto the unit sphere) instead; a sketch of that variant appears after the snippet below. They find that performance drops a bit with this type of loss term, and argue that it's too restrictive, especially since their covariance regularization term already prevents dimension collapse.

In PyTorch code:

 

# invariance loss
loss_inv = F.mse_loss(z_a, z_b)
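For comparison, a SimSiam-style normalized variant of this term might look something like the following sketch. This is not part of VICReg's loss; it is the alternative the authors ablate against, shown here only to make the "projecting onto the unit sphere" point concrete:

import torch.nn.functional as F

# Hypothetical alternative: normalize each embedding to unit length,
# then penalize negative cosine similarity between positive pairs.
z_a_n = F.normalize(z_a, dim=1)
z_b_n = F.normalize(z_b, dim=1)
loss_inv_cosine = -(z_a_n * z_b_n).sum(dim=1).mean()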

Covariance

The covariance term $c(Z)$ captures the covariance between pairs of embedding dimensions:

$$C(Z) = \frac{1}{n-1}\sum_{i=1}^n (z_{i:} - \bar{z})(z_{i:} - \bar{z})^T, \quad \bar{z} = \frac{1}{n}\sum_{i=1}^n z_{i:}$$

$$c(Z) = \frac{1}{d}\sum_{\ell \neq m} C(Z)_{\ell m}^2$$

This one can be a bit tough to wrap your mind around dimensionally. Note that $z_{i:}$ and $\bar{z}$ are both vectors of length $d$, resulting in a $d \times d$ covariance matrix $C$. Whereas $\text{Var}(z_{:j})$ returns a number for each column vector $z_{:j}$, $C(Z)_{\ell m}$ returns a number for the covariance between the centered versions of $z_{:\ell}$ and $z_{:m}$. Minimizing $c(Z)$ means minimizing the off-diagonal components of the covariance matrix between centered embedding variables.

In PyTorch code:

 

# covariance loss: penalize squared off-diagonal entries of each batch's covariance matrix
N, D = z_a.shape
z_a = z_a - z_a.mean(dim=0)  # center each embedding dimension over the batch
z_b = z_b - z_b.mean(dim=0)
cov_z_a = ((z_a.T @ z_a) / (N - 1)).square()  # DxD, element-wise squared covariance
cov_z_b = ((z_b.T @ z_b) / (N - 1)).square()  # DxD
loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D  # sum of squared off-diagonals
loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
loss_cov = loss_c_a + loss_c_b
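If you want to convince yourself that the manual computation above matches a library routine, torch.cov (available in recent PyTorch versions) produces the same $D \times D$ matrix when given variables as rows; this check assumes z_a and cov_z_a are still in scope from the snippet above:

import torch

# torch.cov expects variables in rows and observations in columns,
# and uses the same 1/(N-1) normalization by default.
cov_ref = torch.cov(z_a.T)                           # DxD covariance matrix
assert torch.allclose(cov_ref.square(), cov_z_a, atol=1e-6)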

Combined Loss Function

The loss function is a weighted combination of these three terms:

$$\mathcal{L} = \sum_{i\in\mathcal{D}}\ \sum_{t' \sim \mathcal{T}}\big[\lambda\, s(Z,Z') + \mu\{v(Z)+v(Z')\} + \nu\{c(Z)+c(Z')\}\big]$$

where $\lambda$, $\mu$, $\nu$ are hyper-parameters (set to $\lambda = \mu = 25$, $\nu = 1$ in the paper for the baseline) and the summations are over images $i$ and augmentations $t'$.

In PyTorch code:

 

# weighted combination of the three terms
weighted_inv = loss_inv * self.hparams.invariance_loss_weight
weighted_var = loss_var * self.hparams.variance_loss_weight
weighted_cov = loss_cov * self.hparams.covariance_loss_weight
loss = weighted_inv + weighted_var + weighted_cov
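Putting the pieces together, a self-contained version of the full loss might look like the sketch below. The function name is mine, and the paper's default weights are hard-coded as defaults rather than read from hparams as in the snippets above:

import torch
import torch.nn.functional as F

def vicreg_loss(z_a, z_b, lambda_=25.0, mu=25.0, nu=1.0, eps=1e-4):
    """Sketch of the combined VICReg loss for two n x d batches of embeddings."""
    n, d = z_a.shape

    # invariance: mean-squared distance between positive pairs
    inv = F.mse_loss(z_a, z_b)

    # variance: hinge on the per-dimension standard deviation of each batch
    std_a = torch.sqrt(z_a.var(dim=0) + eps)
    std_b = torch.sqrt(z_b.var(dim=0) + eps)
    var = F.relu(1 - std_a).mean() + F.relu(1 - std_b).mean()

    # covariance: squared off-diagonal entries of each batch's covariance matrix
    z_a_c = z_a - z_a.mean(dim=0)
    z_b_c = z_b - z_b.mean(dim=0)
    cov_a = ((z_a_c.T @ z_a_c) / (n - 1)).square()
    cov_b = ((z_b_c.T @ z_b_c) / (n - 1)).square()
    cov = (cov_a.sum() - cov_a.diagonal().sum()) / d \
        + (cov_b.sum() - cov_b.diagonal().sum()) / d

    return lambda_ * inv + mu * var + nu * cov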

 

 
