import torch
import argparse
def load_model(path, model_type, model_factory=None):
    """Load a model from ``path``, either whole or from a state dict.

    Args:
        path: Checkpoint file previously written with ``torch.save``.
        model_type: ``"model"`` when the file holds the whole pickled model
            object, ``"params"`` when it holds only a ``state_dict``.
        model_factory: Zero-argument callable returning a freshly
            constructed model; required for ``"params"`` because the
            weights need an instance to be loaded into.

    Returns:
        The loaded model.

    Raises:
        ValueError: If ``model_type`` is unknown, or ``model_factory`` is
            missing for ``"params"``.
    """
    if model_type not in ("model", "params"):
        raise ValueError("unknown model type: %r" % (model_type,))
    if model_type == "params" and model_factory is None:
        raise ValueError('model_factory is required when model_type is "params"')

    # Without a GPU, storages that were saved from CUDA must be remapped to
    # the CPU; with a GPU, keep torch.load's default placement.
    map_location = None if torch.cuda.is_available() else "cpu"

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    if model_type == "model":
        return torch.load(path, map_location=map_location)

    model = model_factory()
    model.load_state_dict(torch.load(path, map_location=map_location))
    return model


def save_model(model, path, model_type):
    """Save ``model`` to ``path``, either whole or as a state dict.

    The model is moved to the GPU first when CUDA is available, otherwise
    to the CPU, so the checkpoint is written from that device.

    Args:
        model: The ``torch.nn.Module`` to save.
        path: Destination file.
        model_type: ``"model"`` to pickle the whole model object,
            ``"params"`` to save only ``model.state_dict()``.

    Raises:
        ValueError: If ``model_type`` is unknown.
    """
    if model_type not in ("model", "params"):
        raise ValueError("unknown model type: %r" % (model_type,))

    model = model.cuda() if torch.cuda.is_available() else model.cpu()
    if model_type == "model":
        torch.save(model, path)
    else:
        torch.save(model.state_dict(), path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("-")
    parser.add_argument("--pre_trained_model_type", type=str, choices=["model", "params"], default="model")
    # The path is read below, so it has to be declared on the parser.
    parser.add_argument("--pre_trained_model_path", type=str, required=True)
    args = parser.parse_args()
    # The rest of the script refers to the parsed options as ``opt``.
    opt = args

    # NOTE(review): ``m`` is not defined in this section; it is presumably
    # the model class (an nn.Module subclass) defined or imported elsewhere
    # -- confirm before running.

    # Load a whole model, or load parameters into a fresh model.
    model = load_model(opt.pre_trained_model_path, opt.pre_trained_model_type, m)

    # Save a whole model, or save only its parameters.
    # NOTE(review): as in the original snippet, a freshly constructed ``m()``
    # is written to the same path that was just loaded, overwriting that
    # checkpoint -- confirm this is intended outside of a demo.
    model = m()
    save_model(model, opt.pre_trained_model_path, opt.pre_trained_model_type)
# 【PyTorch】保存和载入模型的两种方法
# 于 2019-05-31 18:59:11 首次发布