实现在Tensorflow,Pytorch模型中嵌入傅里叶变换(STFT/ISTFT)层(类)

1. 在Tensorflow模型中嵌入傅里叶变换(STFT/ISTFT)层(类)

import tensorflow as tf
from tensorflow import  keras

'''定义傅里叶变换fft'''

class FFT(keras.layers.Layer):
    def __init__(self, win_len, **kwargs):
        super(FFT, self).__init__(**kwargs)
        self.win_len =win_len
    
    def build(self,input_shape):
        
       
        self.built = True
        
        
    def call(self, inputs):
        fft_input=tf.signal.fft(inputs)
        length=int(self.win_len/2.+1.)
        outfeature=fft_input[:,:,:,0:length]#(None, 4, 62, 257)
        return outfeature

    def get_config(self):
        config = {"win_len":self.win_len}
        base_config = super(FFT, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))




'''定义傅里叶逆变换ifft'''

class IFFT(keras.layers.Layer):
    def __init__(self, win_len, **kwargs):
        super(IFFT, self).__init__(**kwargs)
        self.win_len = win_len
    
    def build(self,input_shape):
        
       
        self.built = True
        
        
    def call(self, inputs):
        
        length=int(self.win_len/2.)
        enhanced_T=tf.reverse(inputs[:,:,:,1:length],axis=[-1])
        
        enhanced_T=tf.complex(tf.math.real(enhanced_T),-1.*tf.math.imag(enhanced_T))
        Ifft=tf.math.real(tf.signal.ifft(tf.concat([inputs,enhanced_T],-1)))
        return Ifft

    def get_config(self):
        config = {"win_len":self.win_len}
        base_config = super(IFFT, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

"实例调用"
fft=FFT(([None, 4, 9, 512]))#在最后一个维度512做短时傅里叶变换(根据需要更该)
                             #,返回ftt:([None, 4, 9, 257])

ifft=IFFT([None, 4, 9, 257])#在最后一个维度257做短时傅里叶逆变换(根据需要更该)
                               #,返回iftt:([None, 4, 9, 512])


2. 在Pytorch模型中嵌入傅里叶变换(STFT/ISTFT)层(类)


import torch
from torch import nn

class FFT(nn.Module):
    def __init__(self,win_len):
        super(FFT, self).__init__()
        self.win_len=win_len
    def forward(self, x):
        output = torch.fft.fft2(x, dim=-1)
        output=output[:,:,0:int(self.win_len/2+1)]
        return output

class IFFT(nn.Module):
    def __init__(self,win_len):
        super(IFFT, self).__init__()
        self.win_len=win_len
    def forward(self, x):
        "IFFT"
        lenth_ifft=int(self.win_len/2)
        "获取共轭数据"
        out_1=torch.flip(x[:,:,1:lenth_ifft],dims=[-1])
        
        out_2=torch.complex(out_1.real,-1.*out_1.imag)
        "拼接两个部分"
        out3=torch.cat([x,out_2],-1)
        "傅里叶逆变换"
        ioutput = (torch.fft.ifft2(out3, dim=-1)).real
        return ioutput

"测试"
length=100
t1=torch.randint(2,8,size=[4,2,length])
fft1=FFT(100)(t1)
ifft1=IFFT(100)(fft1)
print(ifft1[1,1,10:20])
print(t1[1,1,10:20])















 

 

读书,生活,旅行。 

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值