Background
One family of anomaly-detection methods builds statistics over features taken from a neural network: a mean and a covariance are computed, and the Mahalanobis distance between a new input's features and these statistics serves as the anomaly score.
Computing such statistics requires samples. When there are too many samples to hold at once, or when samples arrive in batches, how to maintain these two statistics becomes worth investigating. Existing blog posts already cover how the mean and variance change under incremental data, so here we only derive how the covariance changes under incremental data.
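For context, here is a minimal sketch of that scoring step (the name mahalanobis_score and all shapes are illustrative assumptions, not from any particular library):

import torch

def mahalanobis_score(feat, mean, cov):
    # Distance of one feature vector to the stored statistics; solving
    # cov @ y = delta avoids forming an explicit inverse of cov.
    delta = feat - mean
    y = torch.linalg.solve(cov, delta)
    return torch.sqrt(delta @ y)

k = 8                              # illustrative feature length
samples = torch.randn(500, k)      # stand-in for collected network features
mean = samples.mean(dim=0)
cov = torch.cov(samples.T)         # (k, k); torch.cov takes variables as rows
print(mahalanobis_score(torch.randn(k), mean, cov))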
Computing the covariance
cov(X,Y)=\sum_{i=1}^n \frac{(X_i-\overline X)(Y_i-\overline Y)}{n-1}
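As a quick sanity check (a minimal sketch; the sizes are arbitrary), this unbiased estimator agrees with torch.cov, which treats rows as variables and columns as observations:

import torch

k, n = 4, 200                                  # k features, n samples
X = torch.randn(k, n)
centered = X - X.mean(dim=1, keepdim=True)
cov_manual = centered @ centered.T / (n - 1)   # the formula above
print(torch.allclose(cov_manual, torch.cov(X), atol=1e-5))  # True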
Data dimensions and the computation
Suppose the data X whose covariance we want is two-dimensional. The original data is k*n; we compute its covariance with itself, and its mean is \overline H, of shape (1*k). The newly added data is k*m, with mean \overline A, also of shape (1*k). Here n and m are the sample counts, and each sample is a feature vector of length k. The mean over all n+m samples together is \overline X, again of shape (1*k).
The products are taken between corresponding entries (outer products of corresponding samples). Letting Y = X^T, so that (Y_i-\overline Y) = (X_i-\overline X)^T, the derivation continues as follows:
cov(X,Y)=\sum_{i=1}^{n+m} \frac{(X_i-\overline X)(Y_i-\overline Y)}{n+m-1}

=\sum_{i=1}^{n} \frac{(X_i-\overline X)(X_i-\overline X)^T}{n+m-1} + \sum_{i=1}^{m} \frac{(X_i-\overline X)(X_i-\overline X)^T}{n+m-1}

For the first sum (the original n samples), rewrite X_i-\overline X = (X_i-\overline H)-(\overline X-\overline H) and expand:

\sum_{i=1}^{n} (X_i-\overline X)(X_i-\overline X)^T = \sum_{i=1}^{n} (X_i-\overline H)(X_i-\overline H)^T - \sum_{i=1}^{n} (X_i-\overline H)(\overline X-\overline H)^T - \sum_{i=1}^{n} (\overline X-\overline H)(X_i-\overline H)^T + \sum_{i=1}^{n} (\overline X-\overline H)(\overline X-\overline H)^T

The two cross terms vanish because \sum_{i=1}^{n} (X_i-\overline H) = 0, and the last term is a constant summed n times, so

\sum_{i=1}^{n} (X_i-\overline X)(X_i-\overline X)^T = (n-1)\,cov(X_n,X_n^T) + n\,(\overline X-\overline H)(\overline X-\overline H)^T

The sum over the m new samples expands the same way with \overline A in place of \overline H. Putting the two parts together:

cov(X,Y) = \Bigl((n-1)\,cov(X_n,X_n^T) + n\,(\overline X-\overline H)(\overline X-\overline H)^T\Bigr)\frac{1}{n+m-1} + \Bigl((m-1)\,cov(X_m,X_m^T) + m\,(\overline X-\overline A)(\overline X-\overline A)^T\Bigr)\frac{1}{n+m-1}
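As a further simplification (not spelled out in the derivation above, but it follows directly), substituting \overline X = (n\overline H + m\overline A)/(n+m) collapses the two mean-shift corrections into a single rank-one term:

n\,(\overline X-\overline H)(\overline X-\overline H)^T + m\,(\overline X-\overline A)(\overline X-\overline A)^T = \frac{nm}{n+m}\,(\overline A-\overline H)(\overline A-\overline H)^T

so the whole update can also be written as

cov(X,Y) = \frac{(n-1)\,cov(X_n,X_n^T) + (m-1)\,cov(X_m,X_m^T) + \frac{nm}{n+m}\,(\overline A-\overline H)(\overline A-\overline H)^T}{n+m-1}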
Code
# encoding: utf-8
'''
@time: 2021/4/29 16:09
@desc: verify the batched (incremental) covariance update against the
       covariance computed over all samples at once
'''
import torch


def torch_cov(X):
    # Unbiased covariance; variables along dim -2, samples along dim -1.
    # https://github.com/pytorch/pytorch/issues/19037#issuecomment-739002393
    D = X.shape[-1]
    mean = torch.mean(X, dim=-1).unsqueeze(-1)
    X = X - mean
    return 1 / (D - 1) * X @ X.transpose(-1, -2)


def get_static_data_cuda(embedding_vectors):
    # Per-position mean and covariance of (B, C, H, W) embedding vectors.
    B, C, H, W = embedding_vectors.size()
    embedding_vectors = embedding_vectors.view(B, C, H * W)
    mean = torch.mean(embedding_vectors, dim=0)  # (C, H*W)
    cov = torch.zeros(C, C, H * W)
    for i in range(H * W):
        cov[:, :, i] = torch_cov(embedding_vectors[:, :, i].T)
    # save learned distribution
    train_outputs = [mean, cov]
    return train_outputs


if __name__ == '__main__':
    data = torch.rand(154, 100, 56, 56)
    print(data.shape)
    data1 = data[:100]  # the original n = 100 samples
    data2 = data[100:]  # the new m = 54 samples
    m, c = get_static_data_cuda(data)   # reference: all samples at once
    m1, c1 = get_static_data_cuda(data1)
    m2, c2 = get_static_data_cuda(data2)
    # ------------------------------------ code below
    n1, n2 = 100, 54
    m12 = (n1 * m1 + n2 * m2) / (n1 + n2)  # combined mean, \overline X
    ch = torch.zeros_like(c)
    ca = torch.zeros_like(c)
    _, _, NN = c.shape
    for i in range(NN):
        hxx = m12[:, i:i + 1] - m1[:, i:i + 1]  # \overline X - \overline H
        axx = m12[:, i:i + 1] - m2[:, i:i + 1]  # \overline X - \overline A
        # rank-one corrections must be weighted by the batch sizes n and m,
        # matching the n*(...) and m*(...) terms in the derivation
        ch[:, :, i] = n1 * (hxx @ hxx.T)
        ca[:, :, i] = n2 * (axx @ axx.T)
    dd = ((n1 - 1) * c1 + ch + ca + (n2 - 1) * c2) / (n1 + n2 - 1)
    print(torch.abs(dd - c).max())  # should be at float-precision level
    # ------------------------------------ code up
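The same merge rule can also be applied repeatedly to fold batches in one at a time. Below is a minimal sketch of that idea (RunningCov is a hypothetical helper of mine, not part of the post's code); it keeps the running scatter matrix (count-1)*cov so the rank-one corrections stay exact:

import torch

class RunningCov:
    def __init__(self, k):
        self.n = 0
        self.mean = torch.zeros(k)
        self.scatter = torch.zeros(k, k)  # centered scatter = (n-1) * cov

    def update(self, batch):  # batch: (m, k)
        m = batch.shape[0]
        bmean = batch.mean(dim=0)
        centered = batch - bmean
        bscatter = centered.T @ centered  # (m-1) * batch covariance
        total = self.n + m
        new_mean = (self.n * self.mean + m * bmean) / total
        dh = (new_mean - self.mean).unsqueeze(1)  # \overline X - \overline H
        da = (new_mean - bmean).unsqueeze(1)      # \overline X - \overline A
        # rank-one corrections weighted by the respective sample counts
        self.scatter = self.scatter + bscatter \
            + self.n * (dh @ dh.T) + m * (da @ da.T)
        self.mean, self.n = new_mean, total

    def cov(self):
        return self.scatter / (self.n - 1)

rc = RunningCov(5)
data = torch.randn(300, 5)
for chunk in data.split(50):
    rc.update(chunk)
print(torch.allclose(rc.cov(), torch.cov(data.T), atol=1e-4))  # True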