一、背景介绍
当我们计算矩阵乘法
其中,,并且假设
当n很大的时候,直接计算计算量会非常的大,因此我们引入随机矩阵乘法(Rangomized matrix multiplication)
二、Randomized Matrix Multiplication
对于任意的概率分布, 其中
根据概率随机的从中选 K 列,从 中选择第 列并记为 ,从中选择第 行并记作
相应的定义:
之后计算:
三、计算方法可行性
下面我们说明上面这种计算方式是可行的。
1、说明这种估计方法是无偏估计
我们定义一个估计量:
(其中定义如上,并且是i.i.d.的)
计算这个估计量的期望:
对上式求和中的每一项:
需要注意这里的 是对所有可能的求期望,也就是可能取到的列(行)的所
有可能性,因此上式可以继续写成:
进而:
因此 是 的无偏估计,这也就说明了这种方法的可行性。
四、方法准确性(Accutacy)
1、方差分析
计算 的方差:
因为, 所以我们认为,均是有界的
这样就有:
结合上面的(2)式
又因为,所以:
结合(2)式我们有:
所以
因为前面已经说明了 是一个无偏估计,因此估计量的方差随着样本量的增加以某个速率减少,那么这个速率就是收敛阶。事实上,这种方法的收敛速度是一阶的,因为:
2、编程验证
import numpy as np
import matplotlib.pyplot as plt
def random_matrix_multiplication(A, B, K,q,p):
m, n = A.shape # m 是行数,n 是列数
n, q = B.shape # p_size 是矩阵 B 的列数
C_hat = np.zeros((m, q))
for _ in range(K):
i = np.random.choice(range(n), p=p) # 选择 A 的列索引
L = A[:, i] / np.sqrt(K * p[i]) # 调整 A 的列
R = B[i, :] / np.sqrt(K * p[i]) # 调整 B 的行
C_hat += np.outer(L, R) # 累加到结果矩阵
return C_hat
def true_matrix_multiplication(A, B):
return np.dot(A, B)
np.random.seed(0)
# 生成随机矩阵 A 和 B
m, n, q = 100, 10000, 100
A = np.random.randn(m, n)
B = np.random.randn(n, q)
# 定义概率分布
p = np.random.dirichlet(np.ones(n), size=1).flatten()
# 计算随机矩阵乘法估计
K_values = [1, 10,100,1000,10000,100000]
errors = []
distances = []
for K in K_values:
C_hat = random_matrix_multiplication(A, B, K, q, p)
C_true = true_matrix_multiplication(A, B)
error = np.linalg.norm(C_true - C_hat, 'fro') # 计算误差
errors.append(error)
distance = np.abs(error - 1/K)
distances.append(distance)
# 输出结果
for K, error in zip(K_values, errors):
print(f"K = {K}, Error = {error}")
# 计算理论的 1/K 曲线
theoretical_errors = [1/K for K in K_values]
# 创建图表
plt.figure()
plt.loglog(K_values, errors, marker='o', label="RMM Error")
plt.loglog(K_values, theoretical_errors, linestyle='--', label="1/K")
# 添加图表元素
plt.xlabel("Number of samples (K)")
plt.ylabel("Error (Frobenius Norm)")
plt.legend()
plt.title("Error Convergence Comparison")
plt.grid(True)
plt.show()
# 绘制 distance 随样本数的变化
plt.figure()
plt.loglog(K_values, distances, marker='o', color='red', label="Distance to 1/K")
plt.xlabel("Number of samples (K)")
plt.ylabel("Distance")
plt.legend()
plt.title("Distance between RMM Error and 1/K")
plt.grid(True)
plt.show()
尽管误差很大,但是可以看出收敛速度和1/K成正比关系:
K = 1, Error = 593939.877764057
K = 10, Error = 302189.4477425598
K = 100, Error = 168372.75693264598
K = 1000, Error = 63433.9901448806
K = 10000, Error = 17227.919149845973
K = 100000, Error = 6214.796966578531
也考虑过如此巨大的误差是不是因为概率选用的是Dirichlet分布,后续改成了均匀分布尝试了一下:
import numpy as np
import matplotlib.pyplot as plt
def random_matrix_multiplication(A, B, K,q,p):
m, n = A.shape # m 是行数,n 是列数
n, q = B.shape # p_size 是矩阵 B 的列数
C_hat = np.zeros((m, q))
for _ in range(K):
i = np.random.choice(range(n), p=p) # 选择 A 的列索引
L = A[:, i] / np.sqrt(K * p[i]) # 调整 A 的列
R = B[i, :] / np.sqrt(K * p[i]) # 调整 B 的行
C_hat += np.outer(L, R) # 累加到结果矩阵
return C_hat
def true_matrix_multiplication(A, B):
return np.dot(A, B)
np.random.seed(0)
# 生成随机矩阵 A 和 B
m, n, q = 100, 10000, 100
A = np.random.randn(m, n)
B = np.random.randn(n, q)
# 定义概率分布
p = p = np.ones(n) / n
# 计算随机矩阵乘法估计
K_values = [1, 10,100,1000,10000,100000]
errors = []
distances = []
for K in K_values:
C_hat = random_matrix_multiplication(A, B, K, q, p)
C_true = true_matrix_multiplication(A, B)
error = np.linalg.norm(C_true - C_hat, 'fro') # 计算误差
errors.append(error)
distance = np.abs(error - 1/K)
distances.append(distance)
# 输出结果
for K, error in zip(K_values, errors):
print(f"K = {K}, Error = {error}")
# 计算理论的 1/K 曲线
theoretical_errors = [1/K for K in K_values]
# 创建图表
plt.figure()
plt.loglog(K_values, errors, marker='o', label="RMM Error")
plt.loglog(K_values, theoretical_errors, linestyle='--', label="1/K")
# 添加图表元素
plt.xlabel("Number of samples (K)")
plt.ylabel("Error (Frobenius Norm)")
plt.legend()
plt.title("Error Convergence Comparison")
plt.grid(True)
plt.show()
# 绘制 distance 随样本数的变化
plt.figure()
plt.loglog(K_values, distances, marker='o', color='red', label="Distance to 1/K")
plt.xlabel("Number of samples (K)")
plt.ylabel("Distance")
plt.legend()
plt.title("Distance between RMM Error and 1/K")
plt.grid(True)
plt.show()
发现效果提升挺多的:
K = 1, Error = 960299.8466397041
K = 10, Error = 313171.49510340684
K = 100, Error = 100144.58623160765
K = 1000, Error = 31705.07474778639
K = 10000, Error = 9933.407925596011
K = 100000, Error = 3208.1525021334205
如有不严谨和错误的地方欢迎各位大佬指正!