seispro-fxdecon中的stft输入问题

在使用seispro中的fxdecon函数的时候,系统报错

RuntimeError: istft requires a complex-valued input tensor matching the output from stft with return_complex=True.

报错提示:带有 return_complex=False 的 stft 已被弃用。在未来的 pytorch 版本中,stft 将为所有输入返回复数张量,而 return_complex=False 将引发错误。

在通过gpt不痛不痒的改了无数次后,在github上发现几个老哥在讨论相关问题,https://github.com/asteroid-team/asteroid/issues/662

内容为:

“It seems like the newer versions of Pytorch have made some changes to the torch.stft and torch.istft functions. I've just run through the same issue and I think I could fix it by doing 'x = torch.view_as_complex(x)' just before calling torch.istft in the line that is raising the error.

Btw, you can also get rid of the deprecation warning you're getting by changing return_complex to True in the call to torch.stft and then doing stft_f = torch.view_as_real(stft_f) just after it.”

具体修改方法如下:

在seispro库里的share.py下找到def fourier_transform_time()和def inverse_fourier_transform_time(),分别修改为:

def fourier_transform_time(data, time_window_len):
    # type: (Tensor, int) -> Tensor
    """Windows and Fourier transforms the data in time.

    Inputs:
        data: A [batch_size, n_traces, n_times] shape Tensor containing the data
        time_window_len: An integer specifying the window length in the time
                         dimension to use when Fourier transforming the data.

    Returns:
        data_fx: A [batch_size, n_freqs, n_time_windows, n_traces, 2] shape
                 Tensor containing the windowed and Fourier transformed
                 data

    """
    # Use the Short-Time Fourier Transform (STFT) to window data in time and
    # Fourier transform. This requires that the data is in 2D, with time
    # in the final dimension, so we need to combine the trace and batch
    # dimensions. To facilitate later steps of the process, we then shift
    # the trace dimension.
    # [batch_size, n_traces, n_times]
    # -> [batch_size * n_traces, n_times]
    # -> [batch_size * n_traces, n_freqs, n_time_windows, 2]
    # -> [batch_size, n_traces, n_freqs, n_time_windows, 2]
    # -> [batch_size, n_freqs, n_time_windows, n_traces, 2]
    batch_size, n_traces, n_times = data.shape
    dtype = data.dtype
    device = data.device
    time_window = torch.hann_window(time_window_len, dtype=dtype, device=device)
    data_fx = torch.stft(
        data.reshape(-1, n_times),
        time_window_len,
        hop_length=time_window_len // 2,
        window=time_window, 
        #https://github.com/asteroid-team/asteroid/issues/662
        #return_complex=False,
        return_complex=True,#改为True
    )
    data_fx = torch.view_as_real(data_fx)  #修改: 将复值转换为实值表示

    n_freqs, n_time_windows = data_fx.shape[1:3]
    data_fx = data_fx.reshape(batch_size, n_traces, n_freqs, n_time_windows, 2)
    data_fx = data_fx.permute(0, 2, 3, 1, 4)
    return data_fx
def inverse_fourier_transform_time(data_fx, time_window_len, n_times):
    # type: (Tensor, int, int) -> Tensor
    """Inverse Fourier transforms in time and combines overlapping windows.

    Inputs:
        data_fx: A [batch_size, n_freqs, n_time_windows, n_traces, 2] shape
                 Tensor containing the windowed and Fourier transformed
                 data
        time_window_len: An integer specifying the window length in the time
                         dimension to use when Fourier transforming the data.
        n_times: An integer specifying the length of the original data in the
                 time dimension.

    Returns:
        data: A [batch_size, n_traces, n_times] shape Tensor containing the
              data after inverse Fourier transforming and combining windows
    """
    # [batch_size, n_freqs, n_time_windows, n_traces, 2]
    # -> [batch_size, n_traces, n_freqs, n_time_windows, 2]
    # -> [batch_size * n_traces, n_freqs, n_time_windows, 2]
    # -> [batch_size * n_traces, n_times]
    # -> [batch_size, n_traces, n_times]
    batch_size, n_freqs, n_time_windows, n_traces, _ = data_fx.shape
    dtype = data_fx.dtype
    device = data_fx.device
    time_window = torch.hann_window(time_window_len, dtype=dtype, device=device)
    data_fx = data_fx.permute(0, 3, 1, 2, 4)
    data_fx = data_fx.reshape(batch_size * n_traces, n_freqs, n_time_windows, 2)
    data_fx = torch.view_as_complex(data_fx)  # 修改:将实值表示转换为复杂值
    data = torch.istft(
        data_fx,
        time_window_len,
        hop_length=time_window_len // 2,
        window=time_window,
        length=n_times,
        return_complex=False,
    )
    return data.reshape(batch_size, n_traces, n_times)

即可解决问题。

ps.其实后面发现在报错信息里面就提到了修改方法。。。

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值