A gradient question about the SCL loss when reproducing YourTTS with coqui TTS

YourTTS (https://arxiv.org/pdf/2112.02418v3.pdf) is a multi-speaker, multilingual TTS model built on VITS. Most of the model is very close to VITS, and a large part of the paper is devoted to experiments.

One of the more interesting additions is the SCL (speaker consistency loss).

The loss itself is not hard to understand: run the generated speech and the original speech through a speaker encoder and take the cosine similarity between the two resulting speaker embeddings. According to the paper, this speaker encoder is pre-trained, and its parameters should stay fixed while the SCL loss is being computed. That also matches a fairly intuitive argument: if the encoder's parameters were trainable, it could learn to produce nearly identical outputs no matter what the input is, which is not what we want.
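
Restating the paper's definition roughly (from memory, so treat the exact form with some caution), the term looks like

    L_SCL = -(α / n) · Σ_i cos_sim( φ(g_i), φ(h_i) )

where φ(·) is the pre-trained speaker encoder, g_i and h_i are the ground-truth and generated waveforms for sample i, n is the batch size, and α weights the term in the total loss.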

Coqui TTS has a reproduction of YourTTS on GitHub. I won't paste the whole codebase here; let's focus on how the SCL loss is implemented.

First, look at init_multispeaker in TTS/tts/models/vits.py:

    def init_multispeaker(self, config: Coqpit):
        """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
        or with external `d_vectors` computed from a speaker encoder model.

        You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.

        Args:
            config (Coqpit): Model configuration.
            data (List, optional): Dataset items to infer number of speakers. Defaults to None.
        """
        self.embedded_speaker_dim = 0
        self.num_speakers = self.args.num_speakers
        self.audio_transform = None

        if self.speaker_manager:
            self.num_speakers = self.speaker_manager.num_speakers

        if self.args.use_speaker_embedding:
            self._init_speaker_embedding()

        if self.args.use_d_vector_file:
            self._init_d_vector()

        # TODO: make this a function
        if self.args.use_speaker_encoder_as_loss:
            if self.speaker_manager.speaker_encoder is None and (
                not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
            ):
                raise RuntimeError(
                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                )

            self.speaker_manager.speaker_encoder.eval()
            print(" > External Speaker Encoder Loaded !!")

            if (
                hasattr(self.speaker_manager.speaker_encoder, "audio_config")
                and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
            ):
                self.audio_transform = torchaudio.transforms.Resample(
                    orig_freq=self.audio_config["sample_rate"],
                    new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                )
            # pylint: disable=W0101,W0105
            self.audio_transform = torchaudio.transforms.Resample(
                orig_freq=self.config.audio.sample_rate,
                new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
            )

The part that matters here is:

        if self.args.use_speaker_encoder_as_loss:
            if self.speaker_manager.speaker_encoder is None and (
                not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
            ):
                raise RuntimeError(
                    " [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
                )

            self.speaker_manager.speaker_encoder.eval()
            print(" > External Speaker Encoder Loaded !!")

This block makes sure a speaker encoder model has been loaded and puts it in eval mode.

Around line 901, inside the forward function, we find where it gets used:

        if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
            # concate generated and GT waveforms
            wavs_batch = torch.cat((wav_seg, o), dim=0)

            # resample audio to speaker encoder sample_rate
            # pylint: disable=W0105
            if self.audio_transform is not None:
                wavs_batch = self.audio_transform(wavs_batch)

            pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

            # split generated and GT speaker embeddings
            gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
        else:
            gt_spk_emb, syn_spk_emb = None, None

Concatenate the ground-truth and generated waveforms, resample, run them through the encoder, then split the embeddings back into the GT half and the synthesized half. A perfectly ordinary pipeline.
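
A tiny shape-level sketch (toy tensors, not the real model) of that cat/chunk round trip; the one thing worth noting is that the first half of the batch stays aligned with wav_seg (ground truth) and the second half with o (the generator output):

import torch

wav_seg = torch.randn(4, 1, 8000)   # ground-truth segments (made-up shapes)
o = torch.randn(4, 1, 8000)         # generator outputs
wavs_batch = torch.cat((wav_seg, o), dim=0)      # -> (8, 1, 8000)

embs = torch.randn(8, 256)          # stand-in for the speaker encoder output
gt_spk_emb, syn_spk_emb = torch.chunk(embs, 2, dim=0)
print(gt_spk_emb.shape, syn_spk_emb.shape)       # (4, 256) (4, 256)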

The loss is then computed as:

    def cosine_similarity_loss(gt_spk_emb, syn_spk_emb):
        return -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean()
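
One small detail worth noting: torch.nn.functional.cosine_similarity defaults to dim=1, so for (batch, dim) embeddings it returns one similarity per GT/synthesized pair, .mean() averages over the batch, and the leading minus sign turns "make the embeddings as similar as possible" into something the optimizer can minimize.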

Everything looks normal so far, but then this line caught me off guard:

pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

Why is it speaker_encoder.forward?

Let's take a look at the speaker encoder.

There are two speaker encoder implementations in the codebase, one ResNet-based and one LSTM-based, but they are largely the same, so let's look at just one of them (the snippet below is from the LSTM variant).

    def forward(self, x, l2_norm=True):
        """Forward pass of the model.

        Args:
            x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
                to compute the spectrogram on-the-fly.
            l2_norm (bool): Whether to L2-normalize the outputs.

        Shapes:
            - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
        """
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                if self.use_torch_spec:
                    x.squeeze_(1)
                    x = self.torch_spec(x)
                x = self.instancenorm(x).transpose(1, 2)
        d = self.layers(x)
        if self.use_lstm_with_projection:
            d = d[:, -1]
        if l2_norm:
            d = torch.nn.functional.normalize(d, p=2, dim=1)
        return d

    @torch.no_grad()
    def inference(self, x, l2_norm=True):
        d = self.forward(x, l2_norm=l2_norm)
        return d

    def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
        """
        Generate embeddings for a batch of utterances
        x: 1xTxD
        """
        max_len = x.shape[1]

        if max_len < num_frames:
            num_frames = max_len

        offsets = np.linspace(0, max_len - num_frames, num=num_eval)

        frames_batch = []
        for offset in offsets:
            offset = int(offset)
            end_offset = int(offset + num_frames)
            frames = x[:, offset:end_offset]
            frames_batch.append(frames)

        frames_batch = torch.cat(frames_batch, dim=0)
        embeddings = self.inference(frames_batch)

        if return_mean:
            embeddings = torch.mean(embeddings, dim=0, keepdim=True)

        return embeddings

There is a compute_embedding function that computes the embedding for an utterance, and it goes through self.inference, which is wrapped in @torch.no_grad() so that gradients are cut off.
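
As a quick standalone sanity check (a toy snippet, not coqui code): a function decorated with @torch.no_grad() really does skip graph construction, so nothing computed inside it can carry gradients.

import torch
from torch import nn

lin = nn.Linear(3, 3)

@torch.no_grad()
def infer(x):
    return lin(x)

y = lin(torch.ones(1, 3))    # normal call: tracked by autograd
z = infer(torch.ones(1, 3))  # decorated call: detached from the graph
print(y.requires_grad, z.requires_grad)  # True False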

All of this matches expectations.

So why does vits.py call forward instead? Isn't there a risk of updating the speaker encoder through its gradients?

My first thought was: maybe the speaker encoder's parameters are simply left out of the optimizer?

    def get_optimizer(self) -> List:
        """Initiate and return the GAN optimizers based on the config parameters.
        It returnes 2 optimizers in a list. First one is for the generator and the second one is for the discriminator.
        Returns:
            List: optimizers.
        """
        # select generator parameters
        optimizer0 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)

        gen_parameters = chain(params for k, params in self.named_parameters() if not k.startswith("disc."))
        optimizer1 = get_optimizer(
            self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
        )
        return [optimizer0, optimizer1]

Unfortunately not. optimizer0 takes the discriminator's parameters, and optimizer1 takes every parameter whose name does not start with "disc." - nothing here excludes the speaker encoder.
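
To make that filter concrete (a toy module, not the coqui code): named_parameters() prefixes each parameter with the attribute path of its registered submodule, so the filter only drops the discriminator - it has no way to exclude something that was never registered as a submodule in the first place.

from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(2, 2)
        self.disc = nn.Linear(2, 2)

toy = Toy()
print([k for k, _ in toy.named_parameters()])
# ['gen.weight', 'gen.bias', 'disc.weight', 'disc.bias']
gen_params = [p for k, p in toy.named_parameters() if not k.startswith("disc.")]
print(len(gen_params))  # 2 -> only gen's weight and bias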

Then could it be handled with a detach, the same way the discriminator update is? Still in vits.py, around line 1055:

            # generator pass
            outputs = self.forward(
                tokens,
                token_lenghts,
                spec,
                spec_lens,
                waveform,
                aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids, "language_ids": language_ids},
            )

            # cache tensors for the generator pass
            self.model_outputs_cache = outputs  # pylint: disable=attribute-defined-outside-init

            # compute scores and features
            scores_disc_fake, _, scores_disc_real, _ = self.disc(
                outputs["model_outputs"].detach(), outputs["waveform_seg"]
            )

But here the detach is only applied to G's output for the discriminator pass.
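
For comparison, a minimal illustration (again a toy, not the real training step) of what that detach achieves: the discriminator loss produces gradients for d only, never for g.

import torch
from torch import nn

g = nn.Linear(2, 2)   # stands in for the generator
d = nn.Linear(2, 2)   # stands in for the discriminator

fake = g(torch.ones(1, 2))
loss = d(fake.detach()).sum()     # detach cuts the graph back to g
loss.backward()
print(g.weight.grad)              # None: no gradient reached g
print(d.weight.grad is not None)  # True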

Is the speaker encoder really being trained, then?

Let's look at that line again:

pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)

The speaker_encoder lives inside speaker_manager - so what exactly is speaker_manager?

It is defined in TTS/tts/utils/speakers.py (I removed some comments for readability):

class SpeakerManager:
    """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
    in a way that can be queried by speaker or clip.

    """
    def __init__(
        self,
        data_items: List[List[Any]] = None,
        d_vectors_file_path: str = "",
        speaker_id_file_path: str = "",
        encoder_model_path: str = "",
        encoder_config_path: str = "",
        use_cuda: bool = False,
    ):

        self.d_vectors = {}
        self.speaker_ids = {}
        self.clip_ids = []
        self.speaker_encoder = None
        self.speaker_encoder_ap = None
        self.use_cuda = use_cuda

        if data_items:
            self.speaker_ids, _ = self.parse_speakers_from_data(data_items)

        if d_vectors_file_path:
            self.set_d_vectors_from_file(d_vectors_file_path)

        if speaker_id_file_path:
            self.set_speaker_ids_from_file(speaker_id_file_path)

        if encoder_model_path and encoder_config_path:
            self.init_speaker_encoder(encoder_model_path, encoder_config_path)

The last line of __init__ is where our speaker encoder gets initialized.

Now look back at the first line of the class!

It does not inherit from nn.Module - in fact, it doesn't inherit from anything at all!

Because SpeakerManager is not an nn.Module, assigning it as an attribute of the VITS model never registers the speaker encoder as a submodule. Its weights therefore never show up in the model's parameters(), never reach either optimizer, and are never updated - the speaker encoder stays frozen. (Gradients still flow back through the encoder into the generator during backward, which is exactly what the SCL loss needs; only the encoder's own weights are left untouched.)

Let's run a few small experiments to check that this conclusion is correct.

from itertools import chain
import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        nn.init.constant_(self.linear.weight, 1)
        nn.init.constant_(self.linear.bias, 1)

    def forward(self, x):
        return self.linear(x)

    @torch.no_grad()
    def infer(self, x):
        return self.forward(x)


n1 = MyNet().eval()
n2 = MyNet()
test_data = torch.ones([1, 2, 5])

x1 = n1(test_data)
x2 = n2(x1)

print(x1)
print(x2)

l = nn.MSELoss()(x2, torch.ones([1, 2, 5]))
l.backward()
# o = torch.optim.SGD( chain( n1.parameters(),n2.parameters() ), lr = 0.1)
o = torch.optim.SGD([{'params': n1.parameters()},
                     {'params': n2.parameters()}], 0.1)
# o = torch.optim.SGD([{'params': n2.parameters()}], 0.1)
o.step()

x1 = n1(test_data)
x2 = n2(x1)
print(x1)
print(x2)

"""
# x1 and x2 from the first pass (before the optimizer step)
tensor([[[6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6.]]], grad_fn=<AddBackward0>)
tensor([[[31., 31., 31., 31., 31.],
         [31., 31., 31., 31., 31.]]], grad_fn=<AddBackward0>)
# x1 and x2 from the second pass (after the optimizer step)
tensor([[[-30., -30., -30., -30., -30.],
         [-30., -30., -30., -30., -30.]]], grad_fn=<AddBackward0>)
tensor([[[929.8000, 929.8000, 929.8000, 929.8000, 929.8000],
         [929.8000, 929.8000, 929.8000, 929.8000, 929.8000]]],
       grad_fn=<AddBackward0>)
"""

First, we put both n1's and n2's parameters into the optimizer and run a single step: both networks get updated. Note that n1.eval() does not help here - eval() only changes the behavior of layers such as dropout and batch norm; it neither blocks gradients nor freezes weights.

from itertools import chain
import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        nn.init.constant_(self.linear.weight, 1)
        nn.init.constant_(self.linear.bias, 1)

    def forward(self, x):
        return self.linear(x)

    @torch.no_grad()
    def infer(self, x):
        return self.forward(x)


n1 = MyNet().eval()
n2 = MyNet()
test_data = torch.ones([1, 2, 5])

x1 = n1(test_data)
x2 = n2(x1)

print(x1)
print(x2)

l = nn.MSELoss()(x2, torch.ones([1, 2, 5]))
l.backward()
# o = torch.optim.SGD( chain( n1.parameters(),n2.parameters() ), lr = 0.1)
# o = torch.optim.SGD([{'params': n1.parameters()},
#                      {'params': n2.parameters()}], 0.1)
o = torch.optim.SGD([{'params': n2.parameters()}], 0.1)
o.step()

x1 = n1(test_data)
x2 = n2(x1)
print(x1)
print(x2)
"""
tensor([[[6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6.]]], grad_fn=<AddBackward0>)
tensor([[[31., 31., 31., 31., 31.],
         [31., 31., 31., 31., 31.]]], grad_fn=<AddBackward0>)
tensor([[[6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6.]]], grad_fn=<AddBackward0>)
tensor([[[-186.2000, -186.2000, -186.2000, -186.2000, -186.2000],
         [-186.2000, -186.2000, -186.2000, -186.2000, -186.2000]]],
       grad_fn=<AddBackward0>)

"""

Next, we put only n2's parameters into the optimizer. After one step, n1's parameters are indeed unchanged (x1 prints the same values as before), while n2's have been updated.

import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 5)
        nn.init.constant_(self.linear.weight, 1)
        nn.init.constant_(self.linear.bias, 1)

    def forward(self, x):
        return self.linear(x)

    @torch.no_grad()
    def infer(self, x):
        return self.forward(x)
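
# In this last experiment, t plays the role of SpeakerManager: a plain object
# (not an nn.Module) that merely holds a network. Big_net plays the role of the
# Vits model: self.a is a registered submodule, while self.t_man hides another
# MyNet inside the plain object.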
class t:
    def __init__(self) -> None:
        self.tt = MyNet()

class Big_net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = MyNet()
        self.t_man = t()

    def forward(self, x):
        x =self.t_man.tt.forward(x)
        print('x1', x)
        x = self.a(x)
        print('x2', x)
        return x


bn = Big_net()
test_data = torch.ones([1, 2, 5])

x2 = bn(test_data)

l = nn.MSELoss()(x2, torch.ones([1, 2, 5]))
l.backward()
o = torch.optim.SGD(bn.parameters(), 0.1)
o.step()

x2 = bn(test_data)
# print(x1)
print(bn)

"""
x1 tensor([[[6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6.]]], grad_fn=<AddBackward0>)
x2 tensor([[[31., 31., 31., 31., 31.],
         [31., 31., 31., 31., 31.]]], grad_fn=<AddBackward0>)
x1 tensor([[[6., 6., 6., 6., 6.],
         [6., 6., 6., 6., 6.]]], grad_fn=<AddBackward0>)
x2 tensor([[[-186.2000, -186.2000, -186.2000, -186.2000, -186.2000],
         [-186.2000, -186.2000, -186.2000, -186.2000, -186.2000]]],
       grad_fn=<AddBackward0>)
Big_net(
  (a): MyNet(
    (linear): Linear(in_features=5, out_features=5, bias=True)
  )
)
"""

Here we copy the coqui pattern: a plain class t stands in for SpeakerManager, self.t_man.tt plays the role of n1 (the speaker encoder) and self.a plays the role of n2. The result matches the prediction exactly: after one optimizer step x1 is unchanged while x2 has moved, and printing bn shows that the network hidden inside t does not even appear among Big_net's submodules (and hence not among its parameters). Case closed.

All in all, there are quite a few ways to freeze a network: .detach(), torch.no_grad(), leaving its parameters out of the optimizer, or even copying the original weights back after every update. Today I learned yet another one - hiding the module inside a plain object that is not an nn.Module. If you know other tricks, feel free to share them in the comments~
