当模型被封装在 nn.DataParallel 中时需要注意的事项:
此时真正的模型被包在 wrapper 里,我们需要通过模型的 module 属性来获得真正的模型状态字典,即 model.module.state_dict()。保存示例:
# Save a training checkpoint. The model is wrapped in nn.DataParallel,
# so the real module's weights are reached via model.module.
# NOTE(review): optimizer._optimizer reaches into ScheduledOptim's private
# wrapped optimizer — presumably the project's convention; verify.
checkpoint = {
    "model": model.module.state_dict(),
    "optimizer": optimizer._optimizer.state_dict(),
}
save_path = os.path.join(
    train_config["path"]["ckpt_path"],
    "{}.pth.tar".format(step),
)
torch.save(checkpoint, save_path)
自定义方法:加载某个训练好的模型并获取它对输入的输出
def get_speakermodel(mel, ckpt_path="/home/nicola/LA_SE/output/ckpt/300000.pth.tar"):
    """Load a trained MultiTaskModel checkpoint and return its output for *mel*.

    Args:
        mel: input mel-spectrogram tensor — assumed to already live on
            *device*; TODO confirm against callers.
        ckpt_path: checkpoint file to load. Defaults to the original
            hard-coded path for backward compatibility.

    Returns:
        The model's output for *mel* (speaker id / embedding — verify
        against MultiTaskModel's forward).
    """
    model = MultiTaskModel().to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    # BUG FIX: switch to eval mode so dropout / batch-norm behave
    # deterministically during inference.
    model.eval()
    with torch.no_grad():
        s_id = model(mel)
    return s_id
微调(fine-tuning)的写法:
model = FastSpeech2(preprocess_config, model_config).to(device)
if train:
ckpt_path = os.path.join(
train_config["path"]["ckpt_path"],
"{}.pth.tar".format(args.restore_step),
)
ckpt = torch.load(ckpt_path)
scheduled_optim = ScheduledOptim(
model, train_config, model_config, args.restore_step
)
if args.restore_step: #将这一句注释掉或者设置为True就是微调了
scheduled_optim.load_state_dict(ckpt["optimizer"])
model.train()
return model, scheduled_optim