之前跑了好几个模型,都需要保存和加载模型信息。为了省得每次都上网找,开一个帖记录一下。 1. 保存模型 此处只保存模型参数 model = torch.nn.Linear(1, 2) torch.save(model.state_dict(), "./model.pth") 2. 加载模型 model = torch.nn.Linear(1, 2) model.load_state_dict(torch.load("./model.pth"))