在训练我的时序预测模型的时候,想用pytorch_forecasting库中的RecurrentNetwork构建一个lstm模型,在使用pytorch_lightning库中的trtainer训练器训练我的模型时,在trainer.fit代码行报错显示TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `RecurrentNet,但是RecurrentNetwork它本身是继承了Lightnin Module的,于是我看到网上有人说pytorch_lightning库在更新后导入该模块的时候,之前的import lightning.pytorch要写成import pytorch_lightning,而RecurrentNetwork继承的AutoRegressiveBaseModelWithCovariates模型继承了Lightnin Module,也就是在basemodel.py文件中,我发现导入pytorch_lightning库的代码是import lightning.pytorch,改成import pytorch_lightning就好了
TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `RecurrentNet
最新推荐文章于 2024-07-24 09:00:00 发布