太长不看版:
ndarray变量内元素类型(dtype)为整型可能会导致运算速度变慢非常多,请注意这一点并保持矩阵运算的两变量内元素类型一致。
我有两个ndarray变量: Spike_train (shape of (492, 21474)), N (shape of (21474, 144)), 要对两者作矩阵乘法运算,调用numpy.dot函数,发现非常非常慢:
import numpy as np
np.dot(Spikes_train, N)
这个函数理论上矩阵计算速度是极快的,但是我发现当我进行一下运算时,竟然平均花费13秒~18秒的时间。令人费解。
起初以为传闻不一定对,以为是矩阵过大而导致的。于是进行模拟验证:
import numpy as np
A = np.random.uniform(0,120,(492, 21474))
B = np.random.uniform(0,120,(21474, 144))
C = np.dot(A,B)
print(type(A),type(B))
print(A.shape,B.shape,C.shape)
print(type(A),type(B),type(C))
print(A.dtype,B.dtype,C.dtype)
得到如下结果:
而且浮点型运算速度非常快!np.dot()果然名不虚传。可见我之前调用太慢不是shape的问题。
于是我查看了 Spikes_train, N 的元素类型(dtype)以及变量类型,发现了问题所在:
两个变量中元素类型虽然都是整数,但类型竟然不一样。于是进行变量的类型转换。
虽然比此前要好一些,但仍然非常慢,需要7秒左右时间。
于是将变量类型转化为浮点型。
一瞬间就除了结果,可见如同此前模拟数据一样采用的是浮点型运算会快很多。