基础小波降噪方法(Python)

164 篇文章 35 订阅
46 篇文章 0 订阅

主要内容包括:

Stationary Wavelet Transform (translation-invariant)

Haar wavelet

Hard thresholding of detail coefficients

Universal threshold, or the level-dependent threshold of Han et al. (the default used below)

High-pass filtering by zeroing the approximation coefficients of a 5-level decomposition of a signal sampled at 16 kHz.

import numpy as np
import os
from scipy.signal import detrend
from sklearn.preprocessing import MinMaxScaler
from pywt import Wavelet, threshold, wavedec, waverec
import pywt
from matplotlib import pyplot as plt
from scipy import signal

Create custom functions

def NearestEvenInteger(n):
    """! Round n down to an even number.

    Even inputs are returned unchanged and odd inputs are decremented by one.
    For an integer n this is the largest even integer not exceeding n (an odd
    number has two equally near even neighbours; the lower one is chosen).


    @param n Number to round, typically a signal length


    @return n when n is even, otherwise n - 1
    """
    is_odd = n % 2 != 0
    return n - 1 if is_odd else n




def std(trace, nlevel=5):
    """Robust noise-level estimate for each of the first ``nlevel`` coefficient arrays.

    For level ``i`` the estimate is ``1.4825 * median(|trace[i]|)``, i.e. the
    median absolute coefficient value scaled by the usual normal-consistency
    factor (about 1/0.6745). It is used to rescale the wavelet thresholds.

    Parameters
        trace: indexable collection of coefficient arrays, one per level
            (e.g. the detail coefficients ``coeffs[1:]``); must provide at
            least ``nlevel`` entries.
        nlevel: number of leading levels to estimate.

    Returns
        1D ndarray of length ``nlevel`` with one estimate per level.
    """
    estimates = []
    for level in range(nlevel):
        magnitudes = np.abs(trace[level])
        estimates.append(1.4825 * np.median(magnitudes))
    return np.array(estimates)




def mad(x):
    """Median absolute deviation of ``x``, scaled to estimate a standard deviation.

    Returns ``1.482579 * median(|x - median(x)|)``. The factor (about 1/0.6745)
    makes the result agree with the standard deviation for normally
    distributed data, while staying robust to outliers such as spikes.
    """
    centre = np.median(x)
    abs_dev = np.abs(x - centre)
    return 1.482579 * np.median(abs_dev)




def get_universal_threshold(trace):
    """Universal threshold ``sigma * sqrt(2 * ln(N))`` for ``trace``.

    ``sigma`` is the robust estimate ``mad(trace)`` and ``N`` is the number of
    samples in ``trace``.
    """
    noise_level = mad(trace)
    log_term = 2 * np.log(len(trace))
    return noise_level * np.sqrt(log_term)




def get_han_threshold(trace: np.ndarray, sigma: np.ndarray, coeffs: list, nlevels: int):
    """Level-dependent detail-coefficient thresholds (Han et al.).

    With ``u_i = sigma[i] * sqrt(2 * ln(N))`` and ``N = len(trace)``, the
    threshold for detail index ``i`` -- applied to ``coeffs[1 + i]`` -- is

    * ``i == 0``:                 ``u_0``
    * ``1 <= i <= nlevels - 2``:  ``u_i / ln(i + 1)``
    * ``i == nlevels - 1``:       ``u_i / sqrt(nlevels - 1)``

    When ``nlevels == 1`` only the first rule applies.

    NOTE(review): pywt.swt / wavedec list the detail arrays coarsest level
    first, so index 0 is the *coarsest* detail level here; also ln(2) < 1
    makes the i == 1 threshold larger than u_1. Confirm both against the
    Han et al. formulation before relying on the level ordering.

    Parameters
        trace: the decomposed signal; only its length N is used.
        sigma: per-level noise estimates, ``sigma[i]`` belonging to ``coeffs[1 + i]``.
        coeffs: coefficient list ``[approximation, detail_0, detail_1, ...]``;
            only its length is used.
        nlevels: number of decomposition levels (1 <= nlevels <= len(coeffs) - 1).

    Returns
        1D float ndarray with one threshold per detail array; entries at
        index >= nlevels (if any) are NaN.
    """
    # count samples
    num_samples = len(trace)
    universal_factor = np.sqrt(2 * np.log(num_samples))

    # one slot per detail array; slots not assigned below stay NaN
    details_threshs = np.full(len(coeffs[1:]), np.nan)

    # first detail index: universal threshold scaled by ITS OWN level's sigma
    details_threshs[0] = sigma[0] * universal_factor
    if nlevels == 1:
        # no distinct "last" level; the sqrt(nlevels - 1) rule would divide by zero
        return details_threshs

    # intermediate levels, 1 <= d_i <= nlevels - 2
    for d_i in range(1, nlevels - 1):
        details_threshs[d_i] = sigma[d_i] * universal_factor / np.log(d_i + 1)

    # last level, d_i = nlevels - 1
    details_threshs[nlevels - 1] = (
        sigma[nlevels - 1] * universal_factor / np.sqrt(nlevels - 1)
    )
    return details_threshs




def determine_threshold(
    trace: np.array,
    threshold: str = "han",
    sigma: np.array = None,
    coeffs: np.array = None,
    nlevels: int = None,
):
    """Dispatch to the thresholding rule named by ``threshold``.

    Parameters
        trace: signal (or scaled coefficients) the threshold is computed for.
        threshold: ``"universal"`` -> ``get_universal_threshold(trace)``, a scalar;
            ``"han"`` -> ``get_han_threshold(trace, sigma, coeffs, nlevels)``,
            one threshold per detail level.
        sigma, coeffs, nlevels: only consulted by the ``"han"`` rule.

    Raises
        NotImplementedError: for any other ``threshold`` name.
    """
    if threshold == "han":
        return get_han_threshold(trace, sigma, coeffs, nlevels)
    if threshold == "universal":
        return get_universal_threshold(trace)
    raise NotImplementedError("Choose an implemented threshold!")

Load Quiroga's dataset

# download simulated dataset 01 (Quiroga simulation set hosted on spikesorting.com).
# Plain Python instead of the IPython-only "!curl" shell escape, so this also runs as a script.
import urllib.request
from pathlib import Path

DATA_URL = (
    "http://www.spikesorting.com/Data/Sites/1/download/simdata/Quiroga/"
    "01%20Example%201%20-%200-05/data_01.txt"
)
data_file = Path("../dataset/data_01.txt")
# curl -o does not create missing directories; make sure the target folder exists
data_file.parent.mkdir(parents=True, exist_ok=True)
# skip the ~37 MB download when the file is already present
if not data_file.exists():
    urllib.request.urlretrieve(DATA_URL, str(data_file))


# get project path (the "dataset" folder sits directly below it)
os.chdir("../")
proj_path = os.getcwd()
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 37.0M  100 37.0M    0     0  26.1M      0  0:00:01  0:00:01 --:--:-- 26.1M

# dataset parameters

SFREQ = 16000  # sampling frequency [Hz]

nyquist = SFREQ / 2

# load dataset (path built with os.path.join rather than string concatenation)
trace_01 = np.loadtxt(os.path.join(proj_path, "dataset", "data_01.txt"))

# work on a copy so trace_01 keeps the raw data untouched
tmp_trace = trace_01.copy()


# describe
num_samples = len(tmp_trace)
duration_secs = num_samples / SFREQ

print("duration:", duration_secs, "secs")
print("number of samples:", num_samples)


# notebook cell output: display the loaded trace
tmp_trace
duration: 90.0 secs
number of samples: 1440000

array([-0.05265172, -0.03124187, -0.00282162, ...,  0.01798155,
        0.01678863,  0.0119459 ])

1. Parametrize
# Both planned extensions are implemented further down: the Han et al. threshold
# (get_han_threshold) and clearing the approximation coefficients (high-pass step).

WAVELET = "haar"  # mother wavelet
NLEVEL = 5  # number of decomposition levels
THRESH = "han"  # threshold rule: "han" or "universal"
THRESH_METHOD = "hard"  # thresholding mode: "hard" or "soft"

# Signal-extension mode for the decimated waverec alternative; the SWT path
# (pywt.iswt) takes no mode argument, so this is not used there.
RECON_MODE = "zero"  # alternatives: 'smooth', "symmetric", "antisymmetric", "constant", "periodic", "reflect"


# cutoff frequency of the high-pass step: after NLEVEL dyadic splits the
# approximation band spans [0, nyquist / 2**NLEVEL] (the max of the lowest band)
freq_cutoff = nyquist / 2**NLEVEL
print("Cutoff frequency for high-pass filtering :", freq_cutoff, "Hz")
2. Preprocess

# detrend

# remove the least-squares linear trend (scipy.signal.detrend, default type="linear")
detrended = detrend(tmp_trace)


# normalize data to [0, 1]. sklearn scalers expect a 2-D (n_samples, n_features)
# array, hence the reshape to a single column and the [:, 0] back to 1-D.
# The fitted scaler is kept so the scaling can be undone after reconstruction.
# `detrended` keeps the detrended (unscaled) signal; `normalized` holds the scaled one.
scaler = MinMaxScaler(feature_range=(0, 1), copy=True)
normalized = scaler.fit_transform(detrended.reshape(-1, 1))[:, 0]
3. Wavelet transform
# pywt.swt needs a signal length that is a multiple of 2**NLEVEL (an even
# length is only sufficient for a single level), so drop the few trailing
# samples that do not fit.
swt_multiple = 2**NLEVEL
size = normalized.shape[0] - normalized.shape[0] % swt_multiple


# initialize filter
wavelet = Wavelet(WAVELET)


# compute Wavelet coefficients
# (decimated alternative: coeffs = wavedec(normalized[:size], wavelet, level=NLEVEL))

# translation-invariance modification of the Discrete Wavelet Transform
# that does not decimate coefficients at every transformation level.
# With trim_approx=True the result is [cA_NLEVEL, cD_NLEVEL, ..., cD_1]:
# the final approximation followed by the details, coarsest level first.
coeffs = pywt.swt(normalized[:size], wavelet, level=NLEVEL, trim_approx=True)

# snapshot of the untouched coefficients. list.copy() is shallow; that is enough
# as long as later steps assign new arrays to list entries instead of editing
# the arrays in place.
coeffs_raw = coeffs.copy()


# print approximation and details coefficients
print("approximation coefficients:", coeffs_raw[0])
print("details coefficients:", coeffs_raw[1:])
# Output (first line):
# approximation coefficients: [2.83647772 2.84106911 2.84452235 ... 2.83109023 2.83292184 2.83465441]
details coefficients: [array([-0.00950856, -0.00603211, -0.00248562, ..., -0.01120825,
       -0.0104237 , -0.00988739]), array([-0.01276296,  0.00053794,  0.0101203 , ..., -0.03684457,
       -0.03087185, -0.02255869]), array([-0.0351023 , -0.02894606, -0.01932213, ..., -0.00506488,
       -0.019433  , -0.0299413 ]), array([-0.01386091, -0.01500663, -0.01406618, ...,  0.00959663,
        0.01430558, -0.00087641]), array([-0.00386154, -0.00512596, -0.00548882, ...,  0.00021516,
        0.00087345,  0.01160962])]

4. Denoise
THRESH = "han"  # threshold rule: "han" or "universal"


# estimate the noise standard deviation of each detail level from its coefficients
sigma = std(coeffs[1:], nlevel=NLEVEL)


if THRESH == "universal":
    # one universal threshold per level: computed on the sigma-scaled
    # coefficients, then scaled back to coefficient units
    threshs = [
        determine_threshold(trace=coeffs[1 + level] / sigma[level], threshold=THRESH)
        * sigma[level]
        for level in range(NLEVEL)
    ]
else:
    # level-dependent thresholds (Han et al.); N is the length of the signal
    # that was actually decomposed. Unknown rule names raise NotImplementedError.
    threshs = determine_threshold(
        trace=normalized[:size],
        threshold=THRESH,
        sigma=sigma,
        coeffs=coeffs,
        nlevels=NLEVEL,
    )


# apply one threshold per detail level (mode THRESH_METHOD);
# the approximation coeffs[0] is left untouched here
coeffs[1:] = [
    threshold(coeff_i, value=thr_i, mode=THRESH_METHOD)
    for coeff_i, thr_i in zip(coeffs[1:], threshs)
]


# reconstruct with the inverse SWT and undo the [0, 1] min-max normalization
# (for a decimated wavedec decomposition the counterpart would be
#  waverec(coeffs, wavelet, mode=RECON_MODE))
denoised_trace = pywt.iswt(coeffs, wavelet)
denoised_trace = scaler.inverse_transform(denoised_trace.reshape(-1, 1))[:, 0]

5. High-pass filter

# clear approximation coefficients (set to 0): with the lowest band removed,
# the reconstruction keeps only the detail bands, i.e. a high-pass filtered trace
coeffs[0] = np.zeros_like(coeffs[0])


# sanity check (vectorized; avoids iterating ~1.4M elements with builtin sum)
assert not np.any(coeffs[0]), "not cleared"

6. Reconstruct trace

# Inverse stationary wavelet transform of the thresholded, high-passed
# coefficients, mapped back from the [0, 1] scale to the detrended signal's scale.
# (For a decimated wavedec decomposition the counterpart would be
#  waverec(coeffs, wavelet, mode=RECON_MODE).)
denoised_trace = scaler.inverse_transform(
    pywt.iswt(coeffs, wavelet).reshape(-1, 1)
)[:, 0]

Plot

fig = plt.figure(figsize=(10, 5))

# top panel: samples 200..799 of the raw trace
ax = fig.add_subplot(2, 1, 1)
ax.set_title("Raw signal")
ax.plot(tmp_trace[200:800])

# bottom panel: the same excerpt after denoising and high-pass filtering
ax = fig.add_subplot(2, 1, 2)
ax.set_title("Denoised signal")
ax.plot(denoised_trace[200:800])

plt.tight_layout()

Power spectrum

fs = float(SFREQ)  # 16 kHz sampling frequency, taken from SFREQ


fig, axes = plt.subplots(1, 2, figsize=(10, 3))


# Welch PSD estimate (1024-sample segments) of the raw (left) and denoised
# (right) traces, on log-log axes. loglog replaces the former
# semilogy(basex=...)/semilogx(basey=...) calls, whose keyword names were
# swapped and are not accepted by current matplotlib.
for ax, data in zip(axes, (tmp_trace, denoised_trace)):
    freqs, powers = signal.welch(data, fs, nperseg=1024)
    ax.loglog(freqs, powers)
    ax.set_xlabel("frequency [Hz]")
    ax.set_ylabel("PSD [V**2/Hz]")

# fixed y-range for the denoised spectrum only
axes[1].set_ylim([1e-12, 1e-4])


plt.tight_layout()


擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。
知乎学术咨询:https://www.zhihu.com/consult/people/792359672131756032?isMe=1
擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

哥廷根数学学派

码字不易,且行且珍惜

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值