FastSpeech2的网络框架
class FastSpeech2(nn.Module):
    """FastSpeech2 acoustic model.

    Pipeline: encoder -> variance adaptor -> decoder -> linear mel
    projection -> postnet.
    """

    def __init__(self, preprocess_config, model_config):
        super().__init__()
        self.model_config = model_config

        # Core sub-networks, each configured from the model config.
        self.encoder = Encoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)

        # Project decoder hidden states to n_mel_channels mel bins.
        decoder_hidden = model_config["transformer"]["decoder_hidden"]
        n_mels = preprocess_config["preprocessing"]["mel"]["n_mel_channels"]
        self.mel_linear = nn.Linear(decoder_hidden, n_mels)

        self.postnet = PostNet()
加载训练好的模型
# Build the model and restore weights from a pretrained checkpoint.
model = FastSpeech2(preprocess_config, model_config).to(device)
# Checkpoint file is named "<restore_step>.pth.tar" under the pretrained path.
ckpt_path = os.path.join(
train_config["path"]["pretrained_path"],
"{}.pth.tar".format(args.restore_step),
)
# map_location='cpu' lets the checkpoint load on machines without a GPU;
# .to(device) above already placed the model where it should run.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files.
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
# The checkpoint stores the state dict under the "model" key.
model.load_state_dict(ckpt["model"])
在该模型上加其他模块
# Register extra sub-modules on the loaded model; add_module makes them
# visible as attributes (self.encoder1, ...) and to model.parameters().
model.add_module('encoder1', Encoder1())
model.add_module('encoder2', Encoder2())
# NOTE(review): 'encoder3' is constructed from Encoder2() — confirm whether
# Encoder3() was intended, or whether two instances of Encoder2 are deliberate.
model.add_module('encoder3', Encoder2())
...
要想使用新加载的模块,可以在FastSpeech2的forward中直接调用
# Illustrative snippet: inside FastSpeech2.forward, the modules registered
# via add_module are callable as ordinary attributes.
# NOTE(review): `input` shadows the Python builtin of the same name.
a = self.encoder1(input)
b = self.encoder2(input)
c = self.encoder3(input)
...
对于预训练模型,想要固定某些层的参数可以将requires_grad设为False
# Freeze the mel_linear layer and leave every other parameter trainable.
#
# BUG FIX: the original iterated ckpt["model"].items() and wrote
# `v.float().requires_grad = ...`. Tensor.float() returns a NEW tensor, so
# the assignment mutated a throwaway copy of a checkpoint entry and never
# touched the model's parameters — the freeze silently did nothing.
# Set requires_grad on the model's own parameters instead.
for name, param in model.named_parameters():
    if name in ('mel_linear.weight', 'mel_linear.bias'):
        param.requires_grad = False  # frozen
    else:
        param.requires_grad = True   # trainable
# Optimize only the parameters left trainable (requires_grad == True);
# frozen parameters are excluded so Adam never updates them.
trainable_params = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.Adam(
    trainable_params,
    betas=train_config["optimizer"]["betas"],
    eps=train_config["optimizer"]["eps"],
    weight_decay=train_config["optimizer"]["weight_decay"],
)