官网写的太模糊了,不适合入门的新手 class LitModel(LightningModule): def __init__(self, in_dim, out_dim): super().__init__() # 保存超参数到hparams属性中 self.save_hyperparameters() self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim) # 假设训练并保存了这个模型到PATH路径 LitModel(in_dim=32, out_dim=10) # 使用in_dim=32, out_dim=10 model = LitModel.load_from_checkpoint(PATH) # 覆盖ckpt文件中的超参数,in_dim=128, out_dim=10 model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
Pytorch lightning 加载模型
最新推荐文章于 2024-06-04 10:08:46 发布