下载项目
WNO在github上面的项目地址如下:
https://github.com/csccm-iitd/WNO/tree/main
下载下来后,里面的数据集需要用matlab代码生成,也可以到里面提到的google云盘里面下载数据集
安装环境
然后需要安装环境
注意python版本一定要大于等于3.8,否则会安装失败。并且这个需要torch环境,我用的是torch 1.10.0(CPU版本)和2.0.0+cu117(GPU版本)
运行代码会报错如下,提示安装下面的三个环境
Wavelet convolution requires <Pytorch Wavelets>, <PyWavelets>, <Pytorch Wavelet Toolbox> \n \
For Pytorch Wavelet Toolbox: $ pip install ptwt \n \
For PyWavelets: $ conda install pywavelets \n \
For Pytorch Wavelets: $ git clone https://github.com/fbcotter/pytorch_wavelets \n \
$ cd pytorch_wavelets \n \
$ pip install .')
注意:pytorch_wavelets这个文件最好和要运行的WNO相关的python文件在同一个目录,不然可能会报错。
问题0:
如果报错的话,可以运行一下下面三段程序,看哪个没有安装成功,解决报错问题后再次安装成功即可。
# 1.验证pytorch_wavelets是否安装成功
import torch
from pytorch_wavelets import DWT1DForward, DWT1DInverse # or simply DWT1D, IDWT1D
dwt = DWT1DForward(wave='db6', J=3)
X = torch.randn(10, 5, 100)
yl, yh = dwt(X)
print(yl.shape)
print(yh[0].shape)
print(yh[1].shape)
print(yh[2].shape)
idwt = DWT1DInverse(wave='db6')
x = idwt((yl, yh))
#2.验证pywt
import pywt
cA, cD = pywt.dwt([1, 2, 3, 4], 'db1')
#3.验证ptwt
import torch
import numpy as np
import pywt
import ptwt # use "from src import ptwt" for a cloned the repo
# generate an input of even length.
data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
data_torch = torch.from_numpy(data.astype(np.float32))
wavelet = pywt.Wavelet('haar')
如果输出下面这个,并且没有报错信息,说明没有问题,安装成功。:
torch.Size([10, 5, 22])
torch.Size([10, 5, 55])
torch.Size([10, 5, 33])
torch.Size([10, 5, 22])
问题1:
报错内容是:找不到DWT1D这个模型
将pytorch_wavelets放到代码同目录然后再pip install .安装后即可解决这个问题
问题2:
AttributeError: module 'dill' has no attribute 'extend'
解决方法:
这个错误表明 dill
模块没有 extend
属性。这个问题可能是由于 dill
模块的版本不兼容导致的。更新 dill
模块: 首先尝试更新 dill
模块到最新版本,这通常可以解决版本兼容问题。
pip install --upgrade dill