小波变换是一种特殊类型的数学变换,通俗来说,是用有限波的平移和缩放表示信号,就其应用而言,离散小波变换DWT常用于信号编码(典型的JPEG2000格式),而连续小波变换一般用于信号分析。小波的缺点之一是必须事先选择要使用的母小波,例如轴承故障检测中经常使用的Morlet小波,适用于地震信号处理的Ricker小波,适用模态分析的Laplace小波等等,这些一定程度上限制了小波的应用范围。
相关参考见知乎上的如下文章:
形象易懂讲解算法I——小波变换 - 咚懂咚懂咚的文章 - 知乎 https://zhuanlan.zhihu.com/p/22450818
因此,有这样一种想法,即开发一种机器学习或深度学习模型,该模型能够在给定信号的情况下找到正交小波的滤波器 h 和 g,找到的这些滤波器应该足够“好”,以便在对信号应用小波变换时能保留最多的信息。总体思路是建立一种自编码器,其参数是滤波器 h 和 g,对信号X进行小波变换(滤波器 h 和 g), 逆变换反过来。将损失函数通过梯度下降算法迭代优化 h 和 g 的系数, 然后使用新的系数对信号 X 重复第一步的过程。当满足某些停止条件时,过程结束。其实也比较容易理解,本质就是梯度下降算法迭代改进滤波器 h 和 g的系数。通过损失函数模型学习到的滤波器必须满足如下属性:
模型参数将通过梯度下降来学习,因此必须转化为可微的损失函数:
并增加惩罚项
开始进入正题,首先导入相关模块,没有pywt小波模块的首先要pip install PyWavelets(注意:PyPI上的包名是PyWavelets,安装后以import pywt导入)
import pywt
import numpy as np
from sklearn.metrics import mean_squared_error
import math
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
然后定义编码器解码器模块
class dwt:
    """Forward discrete wavelet transform parameterized by learnable filters."""

    def __init__(self, filter_size, h = None, g = None):
        """Store the low-pass (h) and high-pass (g) filters.

        When a filter is not supplied it is initialized with Gaussian noise
        (mean 1, std 2) of length ``filter_size``; when both filters are
        given, ``filter_size`` is ignored.
        """
        self.h = np.random.normal(1, 2, filter_size) if h is None else h
        self.g = np.random.normal(1, 2, filter_size) if g is None else g
    # Encoder methods (compute / inverse / update_weigths) were still being
    # refined by the author and are not included in this chunk.
class idwt:
    """Inverse transform that mirrors a dwt instance's filter pair."""

    def __init__(self, dwt):
        # Share the encoder's filters so both directions stay in sync.
        self.h, self.g = dwt.h, dwt.g
    # Decoder methods were still being refined by the author and are not
    # included in this chunk.
定义代价函数
def loss_function(h, g, x, x_, W, reconstruction_weight = 1):
    """Total training loss for the learned wavelet filters.

    Combines three terms:
      * reconstruction MSE between the signal x and its reconstruction x_,
        weighted by ``reconstruction_weight``;
      * an L1 sparsity penalty on the wavelet coefficients W;
      * the filter-constraint penalty lhg(h, g) enforcing orthogonal-wavelet
        properties.
    """
    l1 = 0.5  # weight of the coefficient-sparsity penalty
    l2 = 0.5  # weight of the filter-constraint penalty
    r1 = reconstruction_weight
    s1 = mean_squared_error(x, x_)
    s2 = np.sum(np.abs(W))
    s3 = lhg(h, g)
    # BUGFIX: s3 was weighted by l1, leaving l2 assigned but unused; use l2
    # as intended (numerically identical while l1 == l2 == 0.5).
    return r1*s1 + l1*s2 + l2*s3
def lhg(h, g):
    """Penalty measuring how far (h, g) are from valid orthogonal-wavelet filters.

    Penalizes deviation of h from unit energy, deviation of sum(h) from
    sqrt(2), and any nonzero sum of g. Zero iff all three constraints hold.
    """
    unit_energy = (np.sum(h**2) - 1) ** 2
    lowpass_sum = (np.sum(h) - math.sqrt(2)) ** 2
    highpass_sum = np.sum(g) ** 2
    return unit_energy + lowpass_sum + highpass_sum
计算梯度
def compute_gradient_h(h, g, x, loss_function, diff = 0.001):
    """Estimate d(loss)/dh by central finite differences.

    Each coordinate of h is perturbed by +/- diff, the forward and inverse
    transforms are rebuilt with the perturbed filter, and the symmetric
    difference quotient of the loss is taken.
    """
    n = len(h)
    grad = np.zeros(n)
    probe = np.zeros(n)
    for k in range(n):
        probe[k] = diff
        h_hi = h + probe
        h_lo = h - probe
        probe[k] = 0
        enc_hi = dwt(1, h_hi, g)
        enc_lo = dwt(1, h_lo, g)
        dec_hi = enc_hi.inverse()
        dec_lo = enc_lo.inverse()
        W_hi = enc_hi.compute(x)
        W_lo = enc_lo.compute(x)
        rec_hi = dec_hi.compute(W_hi)
        rec_lo = dec_lo.compute(W_lo)
        loss_hi = loss_function(h_hi, g, x, rec_hi, W_hi)
        loss_lo = loss_function(h_lo, g, x, rec_lo, W_lo)
        grad[k] = (loss_hi - loss_lo) / (2 * diff)
    return grad
def compute_gradient_g(h, g, x, loss_function, diff = 0.001):
    """Estimate d(loss)/dg by central finite differences.

    Mirror image of compute_gradient_h: each coordinate of g is perturbed
    by +/- diff while h is held fixed.
    """
    n = len(g)
    grad = np.zeros(n)
    probe = np.zeros(n)
    for k in range(n):
        probe[k] = diff
        g_hi = g + probe
        g_lo = g - probe
        probe[k] = 0
        enc_hi = dwt(1, h, g_hi)
        enc_lo = dwt(1, h, g_lo)
        dec_hi = enc_hi.inverse()
        dec_lo = enc_lo.inverse()
        W_hi = enc_hi.compute(x)
        W_lo = enc_lo.compute(x)
        rec_hi = dec_hi.compute(W_hi)
        rec_lo = dec_lo.compute(W_lo)
        loss_hi = loss_function(h, g_hi, x, rec_hi, W_hi)
        loss_lo = loss_function(h, g_lo, x, rec_lo, W_lo)
        grad[k] = (loss_hi - loss_lo) / (2 * diff)
    return grad
def compute_gradient(h, g, x, loss_function, diff = 0.001):
    """Return the pair (grad wrt h, grad wrt g), both by finite differences."""
    return (
        compute_gradient_h(h, g, x, loss_function, diff),
        compute_gradient_g(h, g, x, loss_function, diff),
    )
定义模型类
class Model:
    """Trains a dwt/idwt pair so the learned filters reconstruct a signal well.

    Attributes:
        my_dwt:   forward transform holding the current filters h and g.
        my_idwt:  inverse transform sharing the same filter pair.
        losses:   loss value recorded at every training epoch.
        min_loss: lowest loss observed so far.
    """
    def __init__(self,filter_size = 2**5):
        # Random filter initialization happens inside dwt.__init__.
        self.my_dwt = dwt(filter_size)
        self.my_idwt = self.my_dwt.inverse()
        self.losses = []
        self.min_loss = math.inf
    def fit(self, x, epochs = 100, learning_rate = 0.001, verbose = True, good_error = None, reconstruction_weight = 1, diff = 0.000001):
        """Optimize filters h and g on signal x by finite-difference gradient descent.

        Args:
            x: input signal to reconstruct.
            epochs: maximum number of optimization steps.
            learning_rate: gradient-descent step size.
            verbose: print the loss at every epoch when True.
            good_error: early-stop threshold; training returns as soon as the
                loss is at or below it. NOTE(review): that branch returns
                without restoring the best filters or printing the summary --
                confirm this is intended.
            reconstruction_weight: weight of the MSE term in loss_function.
            diff: perturbation used by the finite-difference gradients.
        """
        best_h = self.my_dwt.h
        # NOTE(review): g is read from the decoder while h comes from the
        # encoder; idwt copies the encoder's arrays, so the asymmetry looks
        # accidental but harmless.
        best_g = self.my_idwt.g
        for i in range(epochs):
            W = self.my_dwt.compute(x)
            x_ = self.my_idwt.compute(W)
            loss = loss_function(self.my_dwt.h, self.my_dwt.g, x, x_, W, reconstruction_weight)
            self.losses.append(loss)
            if loss < self.min_loss:
                self.min_loss = loss
                # NOTE(review): snapshots are stored by reference -- if
                # update_weigths mutates h/g in place, "best" will track the
                # current filters instead of freezing them; verify.
                best_h = self.my_dwt.h
                best_g = self.my_dwt.g
            if verbose:
                print('Epochs #' + str(i+1) + ": " + str(loss) + " loss")
            if not good_error is None and loss <= good_error:
                return
            hg, gg = compute_gradient(self.my_dwt.h, self.my_dwt.g, x, loss_function, diff)
            # update_weigths (sic) presumably applies one descent step; its
            # definition is not shown in this chunk.
            self.my_dwt.update_weigths(hg, gg, learning_rate)
        # After training, rebuild both transforms from the best filters found.
        self.my_dwt = dwt(1, best_h, best_g)
        self.my_idwt = self.my_dwt.inverse()
        print("Best Loss", self.min_loss)
    def predict(self, x):
        """Round-trip x through the learned forward and inverse transforms."""
        W = self.my_dwt.compute(x)
        return self.my_idwt.compute(W)
    def wavelet(self):
        """Wrap the learned filters as a pywt.Wavelet via a filter bank.

        filter_bank order is [dec_lo, dec_hi, rec_lo, rec_hi]; the
        reconstruction filters are the time-reversed decomposition filters.
        """
        filter_bank = [self.my_dwt.h, self.my_dwt.g, np.flip(self.my_dwt.h), np.flip(self.my_dwt.g)]
        my_wavelet = pywt.Wavelet(name="my_wavelet", filter_bank=filter_bank)
        return my_wavelet
    def dwt(self, x):
        """Return the learned transform's coefficients as one vector, detail first.

        NOTE(review): assumes my_dwt.compute returns an (approx, detail)
        pair here, while fit() uses its result as a single W -- confirm
        against the encoder implementation.
        """
        ca, cd = self.my_dwt.compute(x)
        return np.concatenate([cd, ca])
定义信号生成函数
def generate_wave_coeff(length):
    """Draw sparse random wavelet coefficients.

    Produces 2*length values, each 0.0 with probability 0.9 and otherwise
    uniform in [-1, 1]; the first half is returned as the approximation
    coefficients, the second half as the detail coefficients.
    """
    coeffs = [
        0.0 if np.random.random() < 0.9 else np.random.uniform(-1, 1)
        for _ in range(length * 2)
    ]
    return (coeffs[:length], coeffs[length:])
def generate_signal(length, familie):
    """Synthesize a test signal by inverse DWT of sparse random coefficients
    using the pywt wavelet named by ``familie`` (e.g. 'haar')."""
    approx, detail = generate_wave_coeff(length)
    return pywt.idwt(approx, detail, familie)
定义滤波器之间的距离
def dist(f1, f2):
    """Shift-invariant cosine distance between two filters.

    The shorter filter is zero-padded to the longer one's length, then the
    minimum over all circular shifts i of
        1 - <f1, roll(f2, i)> / (||f1|| * ||f2||)
    is returned; 0 means the filters match up to a circular shift.

    Note: if either filter has zero norm the division follows numpy
    semantics (inf/nan with a warning) -- callers pass non-degenerate
    filters.
    """
    # Accept plain sequences as well as arrays (the original only converted
    # lists as a side effect of the padding concatenate).
    f1 = np.asarray(f1, dtype=float)
    f2 = np.asarray(f2, dtype=float)
    width = max(len(f1), len(f2))
    if width == 0:
        # Both inputs empty: nothing to compare (matches the original's
        # untouched math.inf sentinel).
        return math.inf
    if len(f1) < width:
        f1 = np.concatenate([f1, np.zeros(width - len(f1))])
    if len(f2) < width:
        f2 = np.concatenate([f2, np.zeros(width - len(f2))])
    denom = np.sqrt(np.sum(f1**2)) * np.sqrt(np.sum(f2**2))
    # BUGFIX: the original also tracked the best shift and rolled a local
    # copy of f2 at the end, but that result was never used -- dead code
    # removed; the returned distance is unchanged.
    return min(1 - f1.dot(np.roll(f2, i)) / denom for i in range(width))
pywt模块的内置小波函数如下
'''Haar (haar)
Daubechies (db)
Symlets (sym)
Coiflets (coif)
Biorthogonal (bior)
Reverse biorthogonal (rbio)
“Discrete” FIR approximation of Meyer wavelet (dmey)
Gaussian wavelets (gaus)
Mexican hat wavelet (mexh)
Morlet wavelet (morl)
Complex Gaussian wavelets (cgau)
Shannon wavelets (shan)
Frequency B-Spline wavelets (fbsp)
Complex Morlet wavelets (cmor)'''
选择一个正交小波进行测试
# Pick a built-in orthogonal wavelet as the ground truth to recover.
w = 'haar'
wavelet = pywt.Wavelet(w)
print("is orthogonal?",wavelet.orthogonal)
# BUGFIX: `wavelet.dec_len` was a bare expression, which only displays its
# value in a notebook; print it so the script also shows the filter length.
print("dec_len:", wavelet.dec_len)
生成信号并进行模型拟合
# Generate a 32-sample haar-based test signal and fit a size-2 filter pair.
x = generate_signal(32, w)
model = Model(2)
# good_error=4 stops training early once the loss reaches 4; the
# reconstruction term is weighted 10x relative to the penalties.
model.fit(x, epochs=2000, learning_rate=0.001, good_error=4, reconstruction_weight=10)
重建误差
# Reconstruction error: pass the signal through the learned analysis /
# synthesis pair and compare against the original.
p = model.predict(x)
mse = mean_squared_error(p,x)
plt.plot(np.arange(len(x)), x, c = 'r', label="original")
plt.plot(np.arange(len(p)), p, c = 'b', label="reconstruction")
plt.legend()
plt.title("Signals -->" + " mse: " + str(mse))
# NOTE(review): no plt.show() here and the trailing print() is empty --
# presumably this ran in a notebook where figures render automatically.
print()
小波函数和尺度函数误差
# Compare the learned wavelet/scaling functions against the reference
# wavelet. A filter_bank-constructed pywt.Wavelet is treated as
# biorthogonal, so its wavefun() returns decomposition and reconstruction
# pairs plus the x-grid `al`; the built-in orthogonal wavelet returns a
# single (phi, psi, x) triple.
[phi_d,psi_d,phi_r,psi_r,al] = model.wavelet().wavefun(level=1)
[phi, psi, ao] = pywt.Wavelet(w).wavefun(level=1)
# BUGFIX: removed dead `lo = np.arange(len(ao))` -- it was never used.
plt.figure(1)
plt.plot(ao, psi, c = 'r', label="original")
plt.plot(al,psi_d, c='b', label='learned')
plt.title("Wavelets --> distance:" + str(dist(psi, psi_d)))
plt.legend()
plt.figure(2)
plt.title("Scaling --> distance: " + str(dist(phi, phi_d)))
plt.plot(ao, phi, c = 'r', label="original")
plt.plot(al,phi_d, c='b', label='learned')
plt.legend()
# Compare the DWT coefficients produced by the learned filters against the
# true wavelet's coefficients (detail first, then approximation).
learned_coeffs = model.dwt(x)
ca, cd = pywt.dwt(x, wavelet=w)
true_coeffs = np.concatenate([cd, ca])
# Truncate both vectors to their common length before scoring and plotting.
n_common = min(len(learned_coeffs), len(true_coeffs))
true_coeffs = true_coeffs[:n_common]
learned_coeffs = learned_coeffs[:n_common]
mse_dwt = mean_squared_error(true_coeffs, learned_coeffs)
idx = np.arange(n_common)
plt.plot(idx, true_coeffs, c = 'r', label = 'real')
plt.plot(idx, learned_coeffs, c = 'b', label = 'learned')
plt.title("DWT --> " + str(mse_dwt) + " mse")
plt.legend()
print()
本文详细的代码见如下链接