使用 FFT 对相位进行 unwrap
如果复数的相位是另一个量(例如频率)的线性函数,而我们想对相位做直线拟合,就必须先对相位进行 unwrap(相位解缠),然后才能拟合。numpy.unwrap 可以对角度进行 unwrap,但如果数据中有大段缺失,numpy.unwrap 就会失效。为此我利用 FFT 实现了一种新的 unwrap 方法:先用补零 FFT 的最高峰估计相位随频率变化的斜率,再把每个点的相位平移 2π 的整数倍,使其最接近这条直线。代码如下:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
def fft_unwrap(freq, phs, pk_th=0.0, unwrap_th=1.5*np.pi, precision=1.e-3,
               verbose=False):
    '''
    Unwrap phs assuming phs is a linear function of freq.

    The slope k of phs = k*freq + b is estimated from the highest peak of the
    zero-padded FFT of exp(1j*phs). Every sample is then shifted by the
    multiple of 2*pi that brings it closest to the line through the first
    valid sample with slope k. Missing samples are simply zeroed in the FFT
    input, so this keeps working across large gaps where numpy.unwrap fails.
    When the data span less than about one period, the FFT peak is not
    reliable and numpy.unwrap is used instead.

    freq: array_like, sample positions, phs = a*freq + b. Assumed to lie on a
        uniform grid; the spacing is taken from the two smallest values.
    phs: array_like, in rad, phase to unwrap. Non-finite values mark missing
        samples and are returned unchanged.
    pk_th: float, numpy.unwrap is used only if the second highest FFT peak
        exceeds pk_th times the highest one (and the unwrap_th test passes).
    unwrap_th: float, numpy.unwrap is used when the phase range after
        numpy.unwrap is smaller than this, because the FFT breaks down when
        the phase range is below ~2*pi.
    precision: float, zero padding is added so that the spacing of the k
        grid after the FFT (rad per unit of freq) is at most precision.
    verbose: bool, print diagnostics about the FFT peaks and phase range.
    return: ndarray, unwrapped phase in the same order as the input.
    raises: ValueError if the two smallest freq values are equal.
    '''
    # phs in rad; float copy so integer input is not truncated and the
    # caller's array is never modified.
    phs = np.array(phs, dtype=float)
    freq = np.asarray(freq, dtype=float)
    valid = np.isfinite(phs)
    if np.count_nonzero(valid) < 2:
        # Fewer than two usable samples: nothing to unwrap.
        return phs

    # Work on a frequency-sorted copy; the result is un-sorted at the end.
    isort = np.argsort(freq, kind='stable')
    freq = freq[isort]
    phs = phs[isort]
    valid = valid[isort]
    dfreq = freq[1] - freq[0]
    if not dfreq > 0:
        raise ValueError('freq must not start with repeated values '
                         '(grid spacing freq[1]-freq[0] is %s)' % dfreq)

    # The angular k-grid spacing is 2*pi/(n*dfreq), so
    # n >= 2*pi/(precision*dfreq) makes it at most `precision`.
    n = 2 * np.pi / (precision * dfreq)
    n = max(int(2**np.ceil(np.log2(n))), phs.size)

    # Missing samples stay exactly zero. The mean is removed from the valid
    # samples only, which suppresses the DC peak without filling the gaps
    # with a constant.
    y = np.zeros(phs.size, dtype=complex)
    y[valid] = np.exp(1.j * phs[valid])
    y[valid] -= y[valid].mean()

    spectrum = np.abs(np.fft.fftshift(np.fft.fft(y, n=n)))
    k_axis = 2 * np.pi * np.fft.fftshift(np.fft.fftfreq(n, d=dfreq))
    ipk, pk_dict = find_peaks(spectrum, height=0.)
    order = np.argsort(pk_dict['peak_heights'])
    pk_heights = pk_dict['peak_heights'][order]
    ipk = ipk[order]

    unwrap_phs = np.unwrap(phs[valid])
    span = unwrap_phs.max() - unwrap_phs.min()
    if verbose:
        if pk_heights.size >= 2:
            print('second highest peak: %s (%.2f of highest)'
                  % (pk_heights[-2], pk_heights[-2]/pk_heights[-1]))
        print('maximum phase difference after np.unwrap: %s' % span)

    second = pk_heights[-2] if pk_heights.size >= 2 else 0.
    if ipk.size == 0 or (second > pk_heights[-1]*pk_th and span < unwrap_th):
        # All data within about one period (or no usable FFT peak):
        # fall back to numpy.unwrap on the valid samples.
        phs[valid] = unwrap_phs
    else:
        k = k_axis[ipk[-1]]
        i0 = np.flatnonzero(valid)[0]  # first valid sample anchors the line
        phs_pred = phs[i0] + k * (freq - freq[i0])
        # Shift each sample by the multiple of 2*pi closest to the predicted
        # line; NaN samples stay NaN.
        phs = phs + 2 * np.pi * np.around((phs_pred - phs) / (2 * np.pi))

    # Restore the caller's ordering.
    out = np.empty_like(phs)
    out[isort] = phs
    return out
if __name__ == '__main__':
    # Demo: a noisy linear phase (slope 1) with three 20-sample gaps, shifted
    # by a random offset each trial; every trial is shown in its own figure.
    slope = 1.
    grid = np.linspace(0., 1.8*np.pi, 201)
    gaps = np.r_[30:50, 60:80, 100:120]
    for _trial in range(10):
        x = grid + (np.random.rand() - 0.5) * 2 * np.pi
        noise = 1.5 * (np.random.rand(*x.shape) - 0.5)
        true_phs = slope * x
        true_phs[gaps] = np.nan
        wrapped = np.angle(np.exp(1.J * (true_phs + noise)))
        unwrapped = fft_unwrap(x, wrapped)
        curves = (
            (unwrapped, 'o', 'unwrap'),
            (wrapped, '*', 'phs'),
            (slope * x, '.', 'phs0'),
        )
        for data, marker, label in curves:
            plt.plot(x, data, marker, label=label)
        plt.legend()
        plt.show()