code:https://github.com/facebookresearch/vicreg
Math and Pseudocode Description
Now that we understand the method conceptually, we can dig into the math. If you're someone who thinks better in code, we also include actual snippets from our PyTorch implementation.
In the equations below, let $Z$ be the $n \times d$ matrix representing a batch, where $n$ and $d$ are the batch size and embedding dimension, respectively. Let $z_{i:}$ be the $i$th vector in the batch and let $z_{:j}$ be a vector composed of the $j$th element of each vector in the batch.
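In tensor terms, this notation maps directly onto row and column slicing. Here's a minimal sketch (the variable names and sizes are just illustrative, not from the released code):
import torch

n, d = 256, 8192            # illustrative batch size and embedding dimension
Z = torch.randn(n, d)       # one row per sample, one column per embedding variable

z_row = Z[0, :]   # z_{i:} for i = 0: the embedding of one sample, shape (d,)
z_col = Z[:, 0]   # z_{:j} for j = 0: one embedding variable across the batch, shape (n,)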
Variance
The variance term $v(Z)$ captures the variance of each embedding variable over a batch:
$$\text{Var}(z_{:j}) = \frac{1}{n-1}\sum_{i=1}^n (z_{ij}-\bar{z}_{j})^2, \qquad \bar{z}_{j} = \frac{1}{n}\sum_{i=1}^n z_{ij}$$
$$v(Z) = \frac{1}{d}\sum_{j=1}^{d}\max\left(0,\ \gamma-\sqrt{\text{Var}(z_{:j})+\epsilon}\right)$$
where $\gamma$ is the target value for the standard deviation (they choose $\gamma = 1$), and $\epsilon$ is a small scalar put in place to prevent numerical instabilities (they choose $\epsilon = 0.0001$).
Notice that minimizing $v(Z)$ means forcing the batch-wise standard deviation to be above $\gamma$. As soon as this target is achieved, $v(Z)$ bottoms out at $0$. A hinge function is used here because the point is not to encourage ever-increasing variance; higher variance isn't necessarily better, it just needs to be above a certain threshold to avoid catastrophic failure, i.e., mode collapse.
In PyTorch code:
# variance loss
std_z_a = torch.sqrt(z_a.var(dim=0) + self.hparams.variance_loss_epsilon)
std_z_b = torch.sqrt(z_b.var(dim=0) + self.hparams.variance_loss_epsilon)
loss_v_a = torch.mean(F.relu(1 - std_z_a))  # hinge at gamma = 1
loss_v_b = torch.mean(F.relu(1 - std_z_b))
loss_var = loss_v_a + loss_v_b
Invariance
The invariance term $s(Z,Z')$ captures the invariance between positive pairs of embedding vectors:
$$s(Z,Z') = \frac{1}{n}\sum_i \|z_i-z'_i\|_2^2$$
This is just a simple mean-squared Euclidean distance metric. Notably, the $z$ vectors are un-normalized. In the paper, the authors do some experiments using the cosine similarity metric of SimSiam (which has the effect of projecting the vectors onto the unit sphere) instead. They find that performance drops a bit with this type of loss term, and argue that it's too restrictive, especially since their covariance regularization term already prevents dimension collapse.
In PyTorch code:
# invariance loss
loss_inv = F.mse_loss(z_a, z_b)
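For comparison, here's a rough sketch of the normalized variant the authors ablate. This is my own illustration rather than a snippet from the released code: the embeddings are first projected onto the unit sphere, after which the MSE is equivalent, up to a constant scale, to minimizing negative cosine similarity.
import torch.nn.functional as F

# hypothetical normalized variant (not the loss VICReg actually uses)
z_a_norm = F.normalize(z_a, dim=1)   # project each embedding onto the unit sphere
z_b_norm = F.normalize(z_b, dim=1)
loss_inv_cosine = F.mse_loss(z_a_norm, z_b_norm)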
Covariance
The covariance term $c(Z)$ captures the covariance between pairs of embedding dimensions:
$$C(Z) = \frac{1}{n-1}\sum_{i=1}^n (z_{i:} - \bar{z})(z_{i:} - \bar{z})^T, \qquad \bar{z} = \frac{1}{n}\sum_{i=1}^n z_{i:}$$
$$c(Z) = \frac{1}{d}\sum_{\ell \neq m} C(Z)^2_{\ell m}$$
This one can be a bit tough to wrap your mind around dimensionally. Note that $z_{i:}$ and $\bar{z}$ are both vectors of length $d$, resulting in a $d \times d$ covariance matrix $C$. Whereas $\text{Var}(z_{:j})$ returns a number for each column vector $z_{:j}$, $C(Z)_{\ell m}$ returns a number for the covariance between the centered versions of $z_{:\ell}$ and $z_{:m}$. Minimizing $c(Z)$ means minimizing the off-diagonal components of the covariance matrix between centered embedding variables.
In PyTorch code:
# covariance loss
N, D = z_a.shape
z_a = z_a - z_a.mean(dim=0)
z_b = z_b - z_b.mean(dim=0)
cov_z_a = ((z_a.T @ z_a) / (N - 1)).square() # DxD, entries already squared
cov_z_b = ((z_b.T @ z_b) / (N - 1)).square() # DxD
loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D  # off-diagonal sum only
loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
loss_cov = loss_c_a + loss_c_b
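As a quick sanity check (my own addition, not part of the original snippet), the manual computation above should agree with torch.cov, which treats rows as variables and columns as observations:
# transpose the (N, D) batch so torch.cov sees D variables over N observations;
# z_a was already mean-centered above, which doesn't change the covariance
cov_ref = torch.cov(z_a.T)                                               # DxD
off_diag_sq = cov_ref.square().sum() - cov_ref.square().diagonal().sum()
assert torch.allclose(off_diag_sq / D, loss_c_a, atol=1e-5)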
Combined Loss Function
The loss function is a weighted combination of these three terms:
$$\mathcal{L} = \sum_{i\in\mathcal{D}}\,\sum_{t' \sim \mathcal{T}}\big[\lambda\, s(Z,Z') + \mu\{v(Z)+v(Z')\} + \nu\{c(Z)+c(Z')\}\big]$$
where $\lambda, \mu, \nu$ are hyper-parameters (set to $\lambda = \mu = 25$, $\nu = 1$ in the paper for the baseline) and the summations are over images $i$ and augmentations $t'$.
In PyTorch code:
weighted_inv = loss_inv * self.hparams.invariance_loss_weight
weighted_var = loss_var * self.hparams.variance_loss_weight
weighted_cov = loss_cov * self.hparams.covariance_loss_weight
loss = weighted_inv + weighted_var + weighted_cov
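Putting the pieces together, here's a self-contained sketch of the full loss as a single function. The default weights match the paper's baseline and the epsilon matches the value above, but the function name and overall arrangement are my own rather than a verbatim copy of the released code:
import torch
import torch.nn.functional as F

def vicreg_loss(z_a, z_b, lambda_=25.0, mu=25.0, nu=1.0, gamma=1.0, eps=1e-4):
    """Sketch of the full VICReg loss for two (N, D) batches of embeddings."""
    N, D = z_a.shape

    # invariance: mean-squared distance between positive pairs
    inv_loss = F.mse_loss(z_a, z_b)

    # variance: hinge that pushes each dimension's std above gamma
    std_a = torch.sqrt(z_a.var(dim=0) + eps)
    std_b = torch.sqrt(z_b.var(dim=0) + eps)
    var_loss = torch.mean(F.relu(gamma - std_a)) + torch.mean(F.relu(gamma - std_b))

    # covariance: penalize squared off-diagonal entries of the covariance matrix
    z_a_c = z_a - z_a.mean(dim=0)
    z_b_c = z_b - z_b.mean(dim=0)
    cov_a = (z_a_c.T @ z_a_c) / (N - 1)
    cov_b = (z_b_c.T @ z_b_c) / (N - 1)
    cov_loss = (cov_a.square().sum() - cov_a.square().diagonal().sum()) / D \
             + (cov_b.square().sum() - cov_b.square().diagonal().sum()) / D

    return lambda_ * inv_loss + mu * var_loss + nu * cov_loss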