1. 在TensorFlow模型中嵌入傅里叶变换(FFT/IFFT)层(类):沿最后一维对每一帧做变换,可用于分帧后的STFT/ISTFT。
import tensorflow as tf
from tensorflow import keras
class FFT(keras.layers.Layer):
    """Keras layer: FFT along the last axis, keeping the non-negative-frequency half.

    For a real frame of length ``win_len`` (N) the spectrum is Hermitian, so the
    first ``N // 2 + 1`` bins carry all the information (e.g. 512 -> 257).

    Args:
        win_len: Frame (window) length N along the last axis of the input.
    """

    def __init__(self, win_len, **kwargs):
        super(FFT, self).__init__(**kwargs)
        self.win_len = win_len

    def build(self, input_shape):
        self.built = True

    def call(self, inputs):
        """Return ``fft(inputs)[..., :win_len // 2 + 1]`` as a complex tensor.

        Works for any input rank; e.g. (None, 4, 9, 512) -> (None, 4, 9, 257).
        """
        inputs = tf.convert_to_tensor(inputs)
        # tf.signal.fft only accepts complex64/complex128, so promote real
        # inputs (keeping double precision for float64).
        if not inputs.dtype.is_complex:
            complex_dtype = tf.complex128 if inputs.dtype == tf.float64 else tf.complex64
            inputs = tf.cast(inputs, complex_dtype)
        spectrum = tf.signal.fft(inputs)
        n_bins = int(self.win_len) // 2 + 1
        # `...` keeps all leading axes instead of assuming a rank-4 input.
        return spectrum[..., :n_bins]

    def get_config(self):
        """Serialize ``win_len`` alongside the base layer config."""
        base_config = super(FFT, self).get_config()
        return {**base_config, "win_len": self.win_len}
class IFFT(keras.layers.Layer):
    """Keras layer: rebuild a real frame from its half spectrum via inverse FFT.

    The input holds the ``win_len // 2 + 1`` non-negative-frequency bins produced
    by an FFT of a real frame of length ``win_len`` (N). The missing bins
    ``k = N//2+1 .. N-1`` equal ``conj(X[N-k])``. They are appended (Hermitian
    extension) before ``tf.signal.ifft``, and the real part is returned.

    Args:
        win_len: Full frame length N of the reconstructed signal
            (e.g. 512 for a 257-bin input).
    """

    def __init__(self, win_len, **kwargs):
        super(IFFT, self).__init__(**kwargs)
        self.win_len = win_len

    def build(self, input_shape):
        self.built = True

    def call(self, inputs):
        """Return the real signal of length ``win_len`` along the last axis.

        ``inputs`` must be complex (a ``tf.signal.ifft`` requirement); any rank works.
        """
        n = int(self.win_len)
        # Bins 1 .. (N+1)//2 - 1, reversed and conjugated, fill k = N//2+1 .. N-1.
        # (N+1)//2 equals N/2 for even N and is correct for odd N as well.
        mirrored = tf.math.conj(tf.reverse(inputs[..., 1:(n + 1) // 2], axis=[-1]))
        full_spectrum = tf.concat([inputs, mirrored], axis=-1)
        return tf.math.real(tf.signal.ifft(full_spectrum))

    def get_config(self):
        """Serialize ``win_len`` alongside the base layer config."""
        base_config = super(IFFT, self).get_config()
        return {**base_config, "win_len": self.win_len}
"实例调用"
# win_len is the frame length N along the last axis (change as needed).
# FFT: input (None, 4, 9, 512) -> complex half spectrum (None, 4, 9, 257).
fft = FFT(512)
# IFFT takes the full frame length N (512), not the number of bins (257):
# input (None, 4, 9, 257) -> real signal (None, 4, 9, 512).
ifft = IFFT(512)
2. 在PyTorch模型中嵌入傅里叶变换(FFT/IFFT)层(类):沿最后一维对每一帧做变换,可用于分帧后的STFT/ISTFT。
import torch
from torch import nn
class FFT(nn.Module):
    """FFT along the last axis, keeping the first ``win_len // 2 + 1`` bins.

    For a real frame of length ``win_len`` (N) these non-negative-frequency bins
    carry the whole spectrum (it is Hermitian). Works for inputs of any rank.

    Args:
        win_len: Frame (window) length N along the last axis of the input.
    """

    def __init__(self, win_len):
        super(FFT, self).__init__()
        self.win_len = win_len

    def forward(self, x):
        """Return ``fft(x, dim=-1)[..., :win_len // 2 + 1]`` (complex)."""
        # A plain 1-D FFT on the last axis (the original fft2(dim=-1) did the
        # same thing, but under a misleading name).
        spectrum = torch.fft.fft(x, dim=-1)
        # `...` keeps every leading axis instead of assuming a rank-3 input.
        return spectrum[..., : int(self.win_len) // 2 + 1]
class IFFT(nn.Module):
    """Rebuild a real frame of length ``win_len`` from its half spectrum.

    The input holds the ``win_len // 2 + 1`` non-negative-frequency bins of a real
    frame of length N. The missing bins ``k = N//2+1 .. N-1`` equal
    ``conj(X[N-k])``. They are appended (Hermitian extension), an inverse FFT
    is applied along the last axis, and the real part is returned.

    Args:
        win_len: Full frame length N of the reconstructed signal.
    """

    def __init__(self, win_len):
        super(IFFT, self).__init__()
        self.win_len = win_len

    def forward(self, x):
        """Return the real signal of length ``win_len`` along the last axis of ``x``."""
        n = int(self.win_len)
        # Bins 1 .. (N+1)//2 - 1, reversed, supply k = N//2+1 .. N-1.
        # (N+1)//2 equals N/2 for even N and is also correct for odd N.
        mirrored = torch.flip(x[..., 1:(n + 1) // 2], dims=[-1])
        # Complex conjugate of the mirrored bins.
        conjugated = torch.complex(mirrored.real, -mirrored.imag)
        full_spectrum = torch.cat([x, conjugated], dim=-1)
        return torch.fft.ifft(full_spectrum, dim=-1).real
"测试"
# Round-trip check: IFFT(FFT(x)) should reproduce x up to float rounding.
length = 100
t1 = torch.randint(2, 8, size=[4, 2, length])
# Reuse `length` rather than repeating the literal, so the check stays
# consistent if the frame length is changed.
fft1 = FFT(length)(t1)
ifft1 = IFFT(length)(fft1)
print(ifft1[1, 1, 10:20])
print(t1[1, 1, 10:20])
print("max abs error:", (ifft1 - t1).abs().max().item())
读书,生活,旅行。