Flow matching: learning from Matcha-TTS & CosyVoice

Matcha-TTS implementation

import torch
import torch.nn.functional as F

def fm_compute_loss(self, x1, mask, mu, spks):
    # x1: target mel-spectrogram
    b = x1.shape[0]
    # random timestep t ~ U(0, 1)
    t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
    # sample noise x_0 ~ p(x_0) = N(0, I)
    z = torch.randn_like(x1)

    # conditional OT path y_t and its target velocity u = dy_t/dt
    y = (1 - (1 - self.sigma_min) * t) * z + t * x1
    u = x1 - (1 - self.sigma_min) * z

    pred_y = self.estimator(y, mask, mu, t.squeeze(), spks)
    # masked MSE, normalized by the valid frame count and mel channels
    loss = F.mse_loss(pred_y, u, reduction="sum") / (
        torch.sum(mask) * u.shape[1]
    )
    return loss, y
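The pair (y, u) above is the conditional OT path of flow matching: y_t = (1 - (1 - sigma_min) t) z + t x1, and u is exactly its time derivative. A minimal self-contained check (all shapes hypothetical, not the repo's code) confirms that the regression target u is the slope of the path:

import torch

sigma_min = 1e-4                           # assumed small constant, as in Matcha-TTS
b, c, T = 2, 80, 100                       # hypothetical batch / mel-channel / frame sizes
x1 = torch.randn(b, c, T)                  # stands in for the target mel
z = torch.randn_like(x1)                   # noise sample x_0 ~ N(0, I)
t = torch.rand(b, 1, 1)
h = 1e-4                                   # finite-difference step

def path(t):
    # conditional OT path: y_t = (1 - (1 - sigma_min) * t) * z + t * x1
    return (1 - (1 - sigma_min) * t) * z + t * x1

u = x1 - (1 - sigma_min) * z               # analytic velocity dy/dt
u_fd = (path(t + h) - path(t)) / h         # numerical slope of the path
print(torch.allclose(u, u_fd, atol=1e-3))  # True: u is exactly what the net regresses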
def estimator_forward(y, mask, mu, t, spks):
    # pack the noisy sample with the conditioning, then self-attend
    x = pack(y, mu)
    x = pack(x, spks)
    q, k, v = x, x, x
    x = slf_attn(q, k, v)
    outputs = linear(x)
    return outputs
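At inference, Matcha-TTS integrates the learned vector field with a fixed-step Euler solver, starting from noise at t = 0 and arriving at the mel at t = 1. A hedged sketch of that loop, reusing the estimator signature above (the step count and noise initialisation are assumptions):

import torch

@torch.no_grad()
def euler_sample(estimator, mask, mu, spks, n_steps=10):
    # start from pure noise x_0 ~ N(0, I) with the shape of the mel to generate
    x = torch.randn_like(mu)
    ts = torch.linspace(0.0, 1.0, n_steps + 1, device=mu.device)
    for i in range(n_steps):
        t = ts[i].expand(x.shape[0])           # one scalar timestep per batch item
        dt = ts[i + 1] - ts[i]
        v = estimator(x, mask, mu, t, spks)    # predicted velocity at (x, t)
        x = x + dt * v                         # Euler step along the flow
    return x                                   # approximate sample x_1 (the mel)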

CosyVoice implementation

import random

def fm_forward(self, x1, x1_len, mask, mu, spks, streaming=False):
    # mu:   encoder outputs
    # x1:   target mel-spectrogram
    # cond: a randomly chosen prefix of the prompt mel

    # with prob. 0.5, keep a random prefix (up to 30% of each utterance) as condition
    cond = torch.zeros(x1.shape, device=x1.device)
    for i, j in enumerate(x1_len):
        if random.random() < 0.5:
            continue
        index = random.randint(0, int(0.3 * j))
        cond[i, :index] = x1[i, :index]
    cond = cond.transpose(1, 2)

    b = mu.shape[0]

    # random timestep, optionally warped by a cosine schedule
    t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
    if self.t_scheduler == 'cosine':
        t = 1 - torch.cos(t * 0.5 * torch.pi)
    # sample noise x_0 ~ p(x_0) = N(0, I)
    z = torch.randn_like(x1)

    y = (1 - (1 - self.sigma_min) * t) * z + t * x1
    u = x1 - (1 - self.sigma_min) * z

    # during training, we randomly drop the conditions to trade off mode
    # coverage against sample fidelity; at inference the unconditional branch
    # simply receives zeros in place of the dropped conditions
    if self.training_cfg_rate > 0:
        cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
        mu = mu * cfg_mask.view(-1, 1, 1)
        spks = spks * cfg_mask.view(-1, 1)
        cond = cond * cfg_mask.view(-1, 1, 1)

    pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
    loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
    return loss, y
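The random dropout above is what enables classifier-free guidance at inference: the estimator runs twice, once with the real conditions and once with them zeroed (matching what the dropped training samples saw), and the two velocities are blended. A sketch under that assumption (cfg_velocity and cfg_rate are illustrative names):

import torch

def cfg_velocity(estimator, y, mask, mu, t, spks, cond, cfg_rate=0.7):
    # conditional branch: real encoder outputs, speaker embedding, prompt mel
    v_cond = estimator(y, mask, mu, t, spks, cond)
    # unconditional branch: conditions zeroed, as in the dropped training samples
    v_uncond = estimator(y, mask, torch.zeros_like(mu), t,
                         torch.zeros_like(spks), torch.zeros_like(cond))
    # blend: push the sample toward the conditional direction
    return (1 + cfg_rate) * v_cond - cfg_rate * v_uncond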

def estimator(x, mu, spks, cond):
    x = pack(x, mu, spks, cond)
    x = slf_attn(x)
    outputs = linear(x)
    return outputs
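pack in both estimator sketches is shorthand for concatenating the noisy input with every conditioning stream along the channel axis; one plausible reading (shapes assumed, not the repos' exact code):

import torch

def pack(y, mu, spks, cond):
    # y, mu, cond: (b, c, T); spks: (b, d) utterance-level speaker embedding
    spks_t = spks.unsqueeze(-1).expand(-1, -1, y.shape[-1])  # broadcast over time
    return torch.cat([y, mu, spks_t, cond], dim=1)           # (b, 3c + d, T)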

chunk_fm

  • During training the features get a chunk mask; at inference only the current chunk is prepared, and the previous chunk is stored as the KV cache.
  • The cache starts with seq_len 0; from each returned cache only the last [-chunk_len:] frames are kept as input for the next step; the positions of the features x are computed from their true (absolute) indices.
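The bookkeeping in the second bullet amounts to two lines per step: trim the cache to the last chunk_len frames, and offset the position indices by how many frames have already been consumed so positional encodings match training. A small illustrative sketch (function and argument names are assumptions):

import torch

def prepare_next_chunk(cache, x_len, chunk_len, offset):
    # keep only the last chunk_len frames of the (b, T, d, 2) kv cache
    cache = cache[:, -chunk_len:]
    # positions of the new frames are absolute, continuing from offset
    pos = offset + torch.arange(x_len)
    return cache, pos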

chunk_mask

(figure: chunk attention mask)

  • During training, samples are masked along the seq_len dimension into different visible regions; both chunk masks and full-length masks are used, which speeds up convergence.
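A common way to build such a chunk-causal mask (a sketch of the idea, not necessarily CosyVoice's exact implementation): every frame may attend to all frames in its own chunk and in earlier chunks.

import torch

def chunk_causal_mask(seq_len, chunk_size):
    # frame i may attend to frame j iff j's chunk comes no later than i's
    chunk_id = torch.arange(seq_len) // chunk_size
    return chunk_id.unsqueeze(1) >= chunk_id.unsqueeze(0)  # (seq_len, seq_len) bool

print(chunk_causal_mask(6, 2).int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])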

cache_attn

def slf_attn_cache(x, cache):
    q_in, k_in, v_in = x, x, x
    query = linear_q(q_in)
    key_cache = linear_k(k_in)
    value_cache = linear_v(v_in)
    # NOTE here we judge cache.size(0) instead of cache.size(1), because
    # init_cache has size (2, 0, 512, 2): size(0) is never 0, and concatenating
    # a 0-length cache along dim=1 is a harmless no-op
    if cache.size(0) != 0:
        # streaming steps take this branch
        key = torch.concat([cache[:, :, :, 0], key_cache], dim=1)
        value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)
    else:
        key, value = key_cache, value_cache
    # the returned cache holds only the current chunk's key/value projections
    cache = torch.stack([key_cache, value_cache], dim=3)
    outputs = scaled_dot_product_attention(query, key, value)
    return outputs, cache
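To see the cache flow end to end, here is a self-contained toy version of the same pattern built on F.scaled_dot_product_attention (dimensions and the single-head layout are illustrative, not CosyVoice's configuration):

import torch
import torch.nn.functional as F

d, chunk_len = 64, 8
wq, wk, wv = torch.nn.Linear(d, d), torch.nn.Linear(d, d), torch.nn.Linear(d, d)

def attn_step(x, cache):
    # x: (b, T, d); cache: (b, T_cache, d, 2) holding stacked key/value
    k_new, v_new = wk(x), wv(x)
    key = torch.cat([cache[..., 0], k_new], dim=1)    # prepend cached keys
    value = torch.cat([cache[..., 1], v_new], dim=1)  # prepend cached values
    out = F.scaled_dot_product_attention(wq(x), key, value)
    return out, torch.stack([k_new, v_new], dim=3)    # cache only the new chunk

cache = torch.zeros(1, 0, d, 2)                       # init cache, seq_len = 0
for chunk in torch.randn(1, 4 * chunk_len, d).split(chunk_len, dim=1):
    out, cache = attn_step(chunk, cache)
    cache = cache[:, -chunk_len:]                     # keep only the last chunk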

stream token2wav

(figure: streaming token2wav pipeline)

  • The first packet has no KV cache; the convolution cache exists but is all zeros. Once the first chunk has been inferred, the kv_cache & cnn_cache are stored.
  • The input is token + cache_token, producing the mel frames corresponding to the tokens.
  • The mel2wav stage works the same way: the first call has no HiFT cache and directly outputs the wav for the mel; the last 8 frames are kept as hift_cache, used both for the next hift_wav prediction and for smoothing between adjacent output audio segments; the audio for the first n-8 frames is emitted.
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)):
    # mel -> f0; speech_feat: (1, 80, T)
    f0 = self.f0_predictor(speech_feat)
    # f0 -> sample-level source excitation signal
    s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # (b, n, t)
    s, _, _ = self.m_source(s)
    s = s.transpose(1, 2)  # (1, 1, t * 480), sample level
    # overwrite the start of the source with cache_source to avoid glitches
    # at chunk boundaries
    if cache_source.shape[2] != 0:
        s[:, :, :cache_source.shape[2]] = cache_source

    generated_speech = self.decode(x=speech_feat, s=s)
    return generated_speech, s
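Putting the bullets together, the caller-side bookkeeping around inference might look like this sketch (stream_vocode is a hypothetical helper; the 8-frame cache and hop = 480 samples per frame follow the description and the (1, 1, t * 480) comment above; CosyVoice's exact overlap/fade logic may differ):

import torch

def stream_vocode(hift, mel_chunks, cache_frames=8, hop=480):
    # cache_frames mel frames correspond to cache_frames * hop source samples
    cache_mel = torch.zeros(1, 80, 0)
    cache_src = torch.zeros(1, 1, 0)
    out = []
    for i, mel in enumerate(mel_chunks):
        mel = torch.cat([cache_mel, mel], dim=2)      # prepend cached mel frames
        wav, src = hift.inference(mel, cache_source=cache_src)
        if i < len(mel_chunks) - 1:
            cache_mel = mel[:, :, -cache_frames:]     # keep the last 8 mel frames
            cache_src = src[:, :, -cache_frames * hop:]
            wav = wav[..., :-cache_frames * hop]      # emit only the first n-8 frames
        out.append(wav)
    return torch.cat(out, dim=-1)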