模型的保存
evla_RMSE_log = torch.sqrt(eval_loss/count)
if max(best) <= evla_RMSE_log:
best.append(evla_RMSE_log)
torch.save(ResNet.state_dict(), './saved_models/best_model_{}.pth'.format(epoch))
这里边以 evla_RMSE_log 作为模型好坏的一个标准,以此不断的更替
模型加载
model.load_state_dict(torch.load("./saved_models/best_model.pth"))
torch.cuda.empty_cache()
简洁明了,第二段为清理显存缓存