我的印象是,复数乘法比实数乘法需要更长的时间,因为它需要3个乘法。
但是我尝试了以下方法:
a, b = 3, 4
c, d = 5, 6
print(a*c - b*d, a*d + b*c)
e = 3+4j
f = 5+6j
print(e*f)
%timeit a*c - b*d
%timeit a*d + b*c
%timeit e*f
%timeit a*b
并得到
-9 38
(-9+38j)
110 ns ± 0.2 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
102 ns ± 0.193 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
47.2 ns ± 0.942 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
43 ns ± 1.52 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
为什么复数乘法几乎与实数乘法一样快?
我知道Python是一种高级语言,但是这个结果仍然让我感到困惑。
我能想到的唯一解释是,两种计算都达到了如此低的速度,以某种方式来说,瓶颈不是计算本身,而是其他。
这是我想问的原始问题
我有此版本的离散傅立叶变换,它使用复杂的操作:
from math import e, pi, cos, sin, log2
def W(n, N):
theta = n * -1 * 2 * pi / N
return cos(theta) + 1j*sin(theta)
def dft(signal):
N = len(signal)
total = []
for k in range(N):
s = 0
for n in range(N):
s += signal[n] * W(n*k, N)
total.append(s)
return total
我也有这个版本的DFT,它不使用复杂的操作,而是使用两个实信号,因为信号的虚部被编码为实数组。
def Wri(n, N):
theta = n * -1 * 2 * pi / N
return cos(theta), sin(theta)
def dft_ri(signal_r, signal_i):
N = len(signal_r)
tr, ti = [], []
for k in range(N):
sr = 0
si = 0
for n in range(N):
a = signal_r[n]
b = signal_i[n]
c, d = Wri(n*k, N)
sr += a*c - b*d
si += a*d + b*c
tr.append(sr)
ti.append(si)
return tr, ti
使用以下复杂信号
x = np.linspace(0, T, 1024)
def signal2(t):
return np.where(np.logical_and((t%T>=0), (t%T
s = list(signal2(x))
sr = list(np.real(s))
si = list(np.imag(s))
两个W大致花费相同的时间
%timeit W(8, 3)
%timeit Wri(8, 3)
输出
464 ns ± 0.756 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
424 ns ± 2.21 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
然而
%timeit dft(s)
%timeit dft_ri(sr, si)
输出
935 ms ± 4.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.6 s ± 3.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
题
为什么没有复杂操作的DFT花费的时间比有复杂操作的DFT花费的时间快两倍?
解决方案
您的计时几乎是100%的开销。无论哪种方式,实际的硬件级乘法仅占成本的一小部分。
尝试使用NumPy数组,每个元素的开销要低得多,您会发现一个不同的故事:
In [1]: import numpy
In [2]: x = numpy.ones(10000)
In [3]: y = x.astype(complex)
In [4]: %timeit x*x
4.2 µs ± 51.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [5]: %timeit y*y
16.6 µs ± 696 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [6]: z = x.astype(int)
In [7]: %timeit z*z
6.54 µs ± 1 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
这是每个元素的时间在1纳秒附近,而您的计时更像是每个元素40到100。
您在FFT实现中看到的性能差异也归因于开销。实际上,您的运行时很少花在做数学上。