[Unable to load custom pretrained weight in Pytorch Lightning](https://stackoverflow.com/questions/70661251/unable-to-load-custom-pretrained-weight-in-pytorch-lightning)
It may be that your .pth file is already a state_dict. Try loading the pretrained weights inside your Lightning class:
```python
import torch
import pytorch_lightning as pl

class liteBDRAR(pl.LightningModule):
    def __init__(self):
        super(liteBDRAR, self).__init__()
        self.model = BDRAR()  # BDRAR is the original (non-Lightning) model class
        print('Model Created!')

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path, map_location='cuda:0'), strict=False)

path = './ckpt/BDRAR/3000.pth'
model = liteBDRAR()
model.load_model(path)
```
With the method above, the program runs but the experimental results are very poor, which indicates the weights were not actually loaded (with `strict=False`, any mismatched keys are silently skipped rather than raising an error).
With the method below, the results match expectations.
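A minimal sketch of why `strict=False` can hide the failure, using a toy `nn.Linear` as a stand-in for BDRAR: if the checkpoint keys carry a prefix (e.g. `model.`) that the target module's parameters do not, nothing matches, nothing is loaded, and no error is raised.

```python
import torch
import torch.nn as nn

# toy stand-in for the real model
net = nn.Linear(4, 2)

# simulate a checkpoint whose keys carry a prefix the target module lacks
wrong_keys = {"model." + k: v for k, v in net.state_dict().items()}

# strict=False does not raise; it only reports what was skipped
result = net.load_state_dict(wrong_keys, strict=False)
print(result.missing_keys)     # parameters of net that received no weights
print(result.unexpected_keys)  # checkpoint keys that were ignored
```

Checking `result.missing_keys` after loading is a cheap way to confirm the weights actually landed.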
The reason you're getting this behavior is that you are trying to load plain PyTorch model weights into the Lightning module. When saving checkpoints with Lightning, you don't save only the model state but also a bunch of other info (see here).
What you are looking for is the following:
```python
path = './ckpt/BDRAR/3000.pth'
bdrar = liteBDRAR()
bdrar.load_state_dict(torch.load(path))
```
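If you are unsure which of the two formats a given .pth file holds, a small helper can normalize it before calling `load_state_dict`. This is a sketch with a toy `nn.Linear` standing in for BDRAR; the `load_weights` name and the assumption that a full checkpoint nests weights under a `"state_dict"` key (as Lightning checkpoints do) are illustrative.

```python
import os
import tempfile

import torch
import torch.nn as nn

def load_weights(path):
    """Return the raw state_dict whether the file holds a plain
    state_dict or a full checkpoint nesting one under 'state_dict'."""
    obj = torch.load(path, map_location="cpu")
    if isinstance(obj, dict) and "state_dict" in obj:
        return obj["state_dict"]
    return obj

# demo: save both formats for a toy module and load each one
net = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as d:
    plain_path = os.path.join(d, "plain.pth")
    ckpt_path = os.path.join(d, "full.pth")
    torch.save(net.state_dict(), plain_path)                             # plain state_dict
    torch.save({"epoch": 1, "state_dict": net.state_dict()}, ckpt_path)  # checkpoint-style dict
    keys_plain = sorted(load_weights(plain_path).keys())
    keys_ckpt = sorted(load_weights(ckpt_path).keys())
```

Both loads yield the same parameter keys, so the helper can feed `load_state_dict` regardless of how the file was produced.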