加载预训练模块,并新增模块(以Fast Speech2为例)

FastSpeech2的网络框架

class FastSpeech2(nn.Module):
    """ FastSpeech2 """

    def __init__(self, preprocess_config, model_config):
        super(FastSpeech2, self).__init__()
        self.model_config = model_config

        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Linear(
            model_config["transformer"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
        )
        self.postnet = PostNet()

加载训练好的模型

model = FastSpeech2(preprocess_config, model_config).to(device)
ckpt_path = os.path.join(
            train_config["path"]["pretrained_path"],
            "{}.pth.tar".format(args.restore_step),
        )
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
model.load_state_dict(ckpt["model"])

在该模型上加其他模块

model.add_module('encoder1', Encoder1())
model.add_module('encoder2', Encoder2())
model.add_module('encoder3', Encoder2())
...

要想使用新加载的模块,可以在FastSpeech2的forward中直接调用

a = self.encoder1(input)
b = self.encoder2(input)
c = self.encoder3(input)
...

对于预训练模型,想要固定某些层的参数可以将requires_grad设为Fasle

for k, v in ckpt["model"].items():
    if k not in ['mel_linear.weight', 'mel_linear.bias']:
        v.float().requires_grad = True
    else:
        v.float().requires_grad = False
opt = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            betas=train_config["optimizer"]["betas"],
            eps=train_config["optimizer"]["eps"],
            weight_decay=train_config["optimizer"]["weight_decay"],
        )

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值