之前也写过该问日的解决方法,但是学长在调试代码的时候问我的问题让我决定完善这个问题的答复。
检查自己的运行文件中import的写法
我的运行文件名为train.py文件,如下显示为我的import写法:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import wandb
import yaml
import os
import argparse
import torch.nn as nn
import random
from config import Cfg
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.data.encoders import EncoderNormalizer, TorchNormalizer
from pytorch_forecasting.metrics import QuantileLoss, SMAPE, MAE, RMSE, MAPE, NormalDistributionLoss
from pytorch_lightning.callbacks import ModelCheckpoint
检查base_model.py文件中的import写法
如下是base_model.py文件中的import写法
form lightning.pytorch import LightningModule,Trainer
在base_model.py文件中我们只需要关注这一句就可以了。
如何修改
两者区别可以看出:我们自己的运行文件中使用的是import pytorch_lightning as pl导入pytorch lightning库,而base_model.py文件中使用的是form lightning.pytorch import LightningModule,Trainer 导入pytorch lightning库,所以将自己的运行文件中所有pytorch_lightning替换为lightning.pytorch,代码也是可以运行的。
修改base_model.py文件中的import写法
以我的代码为例,我们也可以在base_model.py文件中将form lightning.pytorch import LightningModule,Trainer修改为form pytorch_lightning import LightningModule,Trainer,但是在代码后续运行时会在
trainer.test(
model=best_tft,
dataloaders=test_dataloader,
)
这一步继续报错显示model must be a LightingModule or torch._dynamo.OptimizedModule got TemporalFusionTransformer,这时候为了得到数据集的测试结果我通常会将这一部分与训练部分分开为两个py文件进行运行,但是这样需要我们在切换两个文件的同时也修改base_model.py文件中的import写法,过于繁琐。