Tencent Music TME MuseTalk: Model Code Walkthrough


GitHub repo: https://github.com/TMElyralab/MuseTalk

1. Understanding the Model


1.1 Training setup

The pipeline resembles Stable Diffusion: the original face image and the lip-masked face image are first encoded into latents by a pretrained encoder and concatenated, and a Stable Diffusion style UNet then predicts the latent of the target frame in a single pass. The audio features (produced by ./models/whisper/tiny.pt) are injected into training as a conditioning mechanism.
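
A minimal sketch of that forward pass, written against the diffusers-style VAE/UNet interfaces that also appear in the inference code below (the masking step and variable names are assumptions for illustration, not the repo's actual training script):

import torch

def training_forward(vae, unet, pe, face, audio_feat, timesteps):
    # face: ground-truth face crop (B, 3, 256, 256); audio_feat: Whisper features for this frame
    masked = face.clone()
    masked[:, :, face.shape[2] // 2 :, :] = 0         # mask the lower half of the face (lip region)
    z_ref = vae.encode(face).latent_dist.sample()     # latent of the original face
    z_mask = vae.encode(masked).latent_dist.sample()  # latent of the masked face
    z_in = torch.cat([z_mask, z_ref], dim=1)          # concatenate along the channel axis
    cond = pe(audio_feat)                             # audio features act as the condition
    return unet(z_in, timesteps, encoder_hidden_states=cond).sample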

1.2 Loss

  • $L_1$: the distance between the predicted latent and the latent of the ground-truth image
  • $L_2$: the difference between the original image and the generated image
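
In symbols (a sketch only; $E$ and $D$ denote the VAE encoder and decoder, $\hat{z}$ the latent predicted by the UNet, and $x$ the ground-truth frame; the exact norm is not spelled out in this post):

$$L_1 = \lVert \hat{z} - E(x) \rVert, \qquad L_2 = \lVert D(\hat{z}) - x \rVert$$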

2. Inference Code Flow

2.1 The main function

Overall flow:

  • Inputs: a video (or image) and an audio file
  • For the audio, the Audio2Feature class encodes it and splits it into chunks (video frame i corresponds to the audio window [i-2, i+2])
  • If the input is a video: split it into frames, extract the face landmarks, and extract the face bounding box (the lower half of the face is masked)
  • The model generates each output frame
  • The generated crop is composited back onto the background for every frame
  • FFmpeg assembles the frames into the final video
def main(args):
    global pe
    if args.use_float16 is True:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()
    
    inference_config = OmegaConf.load(args.inference_config)
    print(inference_config)
#{'task_0': {'video_path': './data/video/yongen.mp4', 'audio_path': './data/audio/yongen.wav'}, 'task_1': {'video_path': './data/video/sun.mp4', 'audio_path': './data/audio/sun.wav', 'bbox_shift': -7}}

    for task_id in inference_config:
        video_path = inference_config[task_id]["video_path"]
        audio_path = inference_config[task_id]["audio_path"]
        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)

        input_basename = os.path.basename(video_path).split('.')[0]
        audio_basename  = os.path.basename(audio_path).split('.')[0]
        output_basename = f"{input_basename}_{audio_basename}"
        result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
        crop_coord_save_path = os.path.join(result_img_save_path, input_basename+".pkl") # only related to video input
        os.makedirs(result_img_save_path,exist_ok =True)
        
        if args.output_vid_name is None:
            # e.g. ./results/yongen_yongen.mp4
            output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
        else:
            output_vid_name = os.path.join(args.result_dir, args.output_vid_name)
        ############################################## extract frames from source video ##############################################
        #if the input is a video, split it into frames
        if get_file_type(video_path)=="video":
            save_dir_full = os.path.join(args.result_dir, input_basename)
            os.makedirs(save_dir_full,exist_ok = True)
            #extract every frame of the video; total frames = fps * duration (s)
            cmd = f"ffmpeg -v fatal -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
            os.system(cmd)
            input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
            fps = get_video_fps(video_path)
        elif get_file_type(video_path)=="image":
            input_img_list = [video_path, ]
            fps = args.fps
        elif os.path.isdir(video_path):  # input img folder
            input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
            input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
            fps = args.fps
        else:
            raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

        #print(input_img_list)
        ############################################## extract audio feature ##############################################
        #encode the audio with the Audio2Feature class
        whisper_feature = audio_processor.audio2feat(audio_path)
        #split the audio features into per-frame chunks
        whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
        ############################################## preprocess input image  ##############################################
        if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
            print("using extracted coordinates")
            with open(crop_coord_save_path,'rb') as f:
                coord_list = pickle.load(f)
            frame_list = read_imgs(input_img_list)
        else:
            print("extracting landmarks...time consuming")
            #get the landmarks of every frame: landmarks 24 to 91 of the face are used, together with the face bounding box
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
            with open(crop_coord_save_path, 'wb') as f:
                pickle.dump(coord_list, f)
                
        i = 0
        input_latent_list = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            # crop the face region
            crop_frame = frame[y1:y2, x1:x2]
            #resize the cropped frame to 256x256 with OpenCV, using INTER_LANCZOS4 interpolation
            crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
            #use the VAE's get_latents_for_unet method to obtain the latent representation of the cropped frame
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)
    
        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
        ############################################## inference batch by batch ##############################################
        print("start inference")
        video_num = len(whisper_chunks)
        batch_size = args.batch_size
        gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
        res_frame_list = []
        for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
            #convert the batch of audio features to a PyTorch tensor and move it to the target device
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype) # torch, B, 5*N,384
            audio_feature_batch = pe(audio_feature_batch)
            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            #predict the latents of the reconstructed frame from the input latents
            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
            #decode the latents back into image frames
            recon = vae.decode_latents(pred_latents)
            for res_frame in recon:
                res_frame_list.append(res_frame)
                
        ############################################## pad to full image ##############################################
        print("pad talking image to original video")
        for i, res_frame in enumerate(tqdm(res_frame_list)):
            #look up the bounding box bbox and the corresponding original frame ori_frame for this result frame
            bbox = coord_list_cycle[i%(len(coord_list_cycle))]
            ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
            x1, y1, x2, y2 = bbox
            try:
                #resize the generated frame res_frame with OpenCV so it matches the bounding box of the original frame
                res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
            except:
#                 print(bbox)
                continue
            #call get_image to blend the generated crop back into the original frame
            combine_frame = get_image(ori_frame,res_frame,bbox)
            #write the composited frame into the temporary image folder, using an eight-digit frame number as the file name
            cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
        #use FFmpeg to turn the composited frames in the temporary folder into a temporary video file temp.mp4
        cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
        print(cmd_img2video)
        os.system(cmd_img2video)
        #use FFmpeg to mux the original audio with the temporary video into the final output video output_vid_name
        cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i temp.mp4 {output_vid_name}"
        print(cmd_combine_audio)
        os.system(cmd_combine_audio)
        
        os.remove("temp.mp4")
        shutil.rmtree(result_img_save_path)
        print(f"result is save to {output_vid_name}")

3. Extracting Audio Features with Whisper

3.1 How Audio2Feature extracts the audio features

Whisper is used to extract the audio features. The overall flow of the class is shown below; the role of each function is explained in the following subsections.

class Audio2Feature():
    def __init__(self, 
                 whisper_model_type="tiny",
                 model_path="./models/whisper/tiny.pt"):
        self.whisper_model_type = whisper_model_type
        self.model = load_model(model_path) #

    def get_sliced_feature(self,
                           feature_array, 
                           vid_idx, 
                           audio_feat_length=[2,2],
                           fps=25):
        """
        Get sliced features based on a given index
        :param feature_array: 
        :param start_idx: the start index of the feature
        :param audio_feat_length:
        :return: 
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
        
        center_idx = int(vid_idx*50/fps) 
        left_idx = center_idx-audio_feat_length[0]*2
        right_idx = center_idx + (audio_feat_length[1]+1)*2
        
        for idx in range(left_idx,right_idx):
            idx = max(0, idx)
            idx = min(length-1, idx)
            x = feature_array[idx]
            selected_feature.append(x)
            selected_idx.append(idx)
        
        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, 384)# 50*384
        return selected_feature,selected_idx

    def get_sliced_feature_sparse(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
        """
        Get sliced features based on a given index
        :param feature_array: 
        :param start_idx: the start index of the feature
        :param audio_feat_length:
        :return: 
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []

        for dt in range(-audio_feat_length[0],audio_feat_length[1]+1):
            left_idx = int((vid_idx+dt)*50/fps)
            if left_idx<1 or left_idx>length-1:
                left_idx = max(0, left_idx)
                left_idx = min(length-1, left_idx)

                x = feature_array[left_idx]
                x = x[np.newaxis,:,:]
                x = np.repeat(x, 2, axis=0)
                selected_feature.append(x)
                selected_idx.append(left_idx)
                selected_idx.append(left_idx)
            else:
                x = feature_array[left_idx-1:left_idx+1]
                selected_feature.append(x)
                selected_idx.append(left_idx-1)
                selected_idx.append(left_idx)
        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, 384)# 50*384
        return selected_feature,selected_idx
    

    def feature2chunks(self,feature_array,fps,audio_feat_length = [2,2]):
        whisper_chunks = []
        whisper_idx_multiplier = 50./fps 
        i = 0
        print(f"video in {fps} FPS, audio idx in 50FPS")
        while 1:
            start_idx = int(i * whisper_idx_multiplier)
            selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i,audio_feat_length=audio_feat_length,fps=fps)
            #print(f"i:{i},selected_idx {selected_idx}")
            whisper_chunks.append(selected_feature)
            i += 1
            if start_idx>len(feature_array):
                break

        return whisper_chunks

    def audio2feat(self,audio_path):
        # get the sample rate of the audio
        result = self.model.transcribe(audio_path)
        embed_list = []
        for emb in result['segments']:
            encoder_embeddings = emb['encoder_embeddings']
            encoder_embeddings = encoder_embeddings.transpose(0,2,1,3)
            encoder_embeddings = encoder_embeddings.squeeze(0)
            start_idx = int(emb['start'])
            end_idx = int(emb['end'])
            emb_end_idx = int((end_idx - start_idx)/2)
            embed_list.append(encoder_embeddings[:emb_end_idx])
        concatenated_array = np.concatenate(embed_list, axis=0)
        return concatenated_array

3.2 The audio is first converted to a log-Mel spectrogram

3.2.1 Use FFmpeg to convert the audio file to 16 kHz mono, then load it as a NumPy array

def load_audio(file: str, sr: int = SAMPLE_RATE):
    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
       #"-" 表示输出到标准输出(stdout),format="s16le" 指定输出格式为 16 位小端序 PCM
       #acodec="pcm_s16le" 指定音频编解码器为 PCM 16 位小端序,ac=1 指定音频通道数为 1(单声道),ar=sr 指定音频采样率
       。
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    #np.int16 represents 16-bit signed integers; 16-bit PCM stores each sample as a signed 16-bit integer in the range [-32768, 32767]
    #np.frombuffer(out, np.int16) interprets the byte buffer as a NumPy array of 16-bit integers
    #flatten() turns it into a 1-D array and astype(np.float32) converts it to float32
    #dividing by 32768.0 normalizes the integer samples to the floating-point range [-1.0, 1.0]
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
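
A quick sanity check on the output (the file name here is hypothetical): for an 8-second clip the returned array has 8 * 16000 = 128000 samples, which is exactly the audio[128000] shape used in section 3.2.3.

audio = load_audio("demo.wav")                    # hypothetical input file
print(audio.dtype, audio.shape)                   # float32, (duration_in_seconds * 16000,)
print(audio.min() >= -1.0 and audio.max() < 1.0)  # samples are normalized to [-1, 1)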

3.2.2 Computing the log-Mel spectrogram

def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
    """Compute the log-Mel spectrogram of an audio file path, NumPy array, or tensor."""
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    #create a Hann window of length N_FFT (400)
    window = torch.hann_window(N_FFT).to(audio.device)
    #short-time Fourier transform: N_FFT (400) is the FFT size, HOP_LENGTH (160) the hop between windows, window is the Hann window, return_complex=True returns complex values
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    #squared magnitudes of the STFT, dropping the last time frame
    magnitudes = stft[:, :-1].abs() ** 2
    #create the Mel filter bank with n_mels (80) filters
    filters = mel_filters(audio.device, n_mels)
    #apply the Mel filters to the squared magnitudes to obtain the Mel spectrogram
    mel_spec = filters @ magnitudes
    #log-compress the Mel spectrogram, then clamp and rescale it
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

3.2.3 Extracting the audio features

audio [128000] -> Mel spectrogram [80, 800] -> pad_or_trim [1, 80, 3000] -> encoder_embeddings [1, 5, 1500, 384] -> whisper_feature [400, 5, 384]
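
The frame-rate arithmetic behind these shapes, assuming an 8-second clip: the STFT uses a hop of 160 samples at 16 kHz (100 Mel frames per second), and the Whisper encoder halves the time resolution, giving 50 feature frames per second:

SAMPLE_RATE, HOP_LENGTH = 16000, 160
duration_s   = 8                               # assumed example clip length
n_samples    = duration_s * SAMPLE_RATE        # 128000 audio samples
n_mel_frames = n_samples // HOP_LENGTH         # 800 -> Mel spectrogram is (80, 800)
n_padded     = 3000                            # padded to Whisper's 30 s window
n_enc_frames = n_padded // 2                   # 1500 -> encoder output length per layer
feat_fps     = SAMPLE_RATE // HOP_LENGTH // 2  # 50 feature frames per second
n_used       = duration_s * feat_fps           # 400 -> whisper_feature is (400, 5, 384)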

def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    #pad or trim the array along the given axis so that it reaches the requested length
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(dim=axis, index=torch.arange(length))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array
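
In this pipeline, per the shape chain in 3.2.3, pad_or_trim is applied to the Mel spectrogram along the time axis, padding it to Whisper's 30-second window of 3000 frames (the [1, 80, 3000] tensor above is presumably this padded Mel with a batch dimension). A hedged usage sketch, with a hypothetical file name:

N_FRAMES = 3000                        # 30 s * 100 Mel frames per second
mel = log_mel_spectrogram("demo.wav")  # (80, 800) for an 8-second clip
mel = pad_or_trim(mel, N_FRAMES)       # zero-padded along the last axis to (80, 3000)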

3.3 Chunking the audio features

3.3.1 feature2chunks

  • Convert each video frame index into an audio-feature index, using the video fps and the assumed 50 FPS feature rate.
  • Use the get_sliced_feature method to slice a fixed-length feature window out of the feature array.
  • Append each window to the whisper_chunks list until the whole feature array has been consumed.
    def feature2chunks(self,feature_array,fps,audio_feat_length = [2,2]):
        whisper_chunks = []
        #convert the video frame index into an audio-feature index, assuming the audio features run at 50 FPS
        whisper_idx_multiplier = 50./fps 
        i = 0
        print(f"video in {fps} FPS, audio idx in 50FPS")
        while 1:
            start_idx = int(i * whisper_idx_multiplier)
            #slice one feature window and its indices out of feature_array
            selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i,audio_feat_length=audio_feat_length,fps=fps)
            #print(f"i:{i},selected_idx {selected_idx}")
            whisper_chunks.append(selected_feature)
            i += 1
            if start_idx>len(feature_array):
                break

        return whisper_chunks

3.3.2 get_sliced_feature

  • Compute the audio-feature center index that corresponds to the given video frame index.
  • Slice a window out of the feature array according to the requested audio feature length.
  • Return the sliced feature chunk together with the indices it came from.

    def get_sliced_feature(self,
                           feature_array, 
                           vid_idx, 
                           audio_feat_length=[2,2],
                           fps=25):
        """
        Get sliced features based on a given index
        :param feature_array: 
        :param start_idx: the start index of the feature
        :param audio_feat_length:
        :return: 
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []
        
        center_idx = int(vid_idx*50/fps) 
        #compute the left and right boundary indices of the audio-feature window, left_idx and right_idx
        left_idx = center_idx-audio_feat_length[0]*2
        right_idx = center_idx + (audio_feat_length[1]+1)*2
        #iterate from left_idx to right_idx, clamp each index into the valid range, and collect the corresponding features
        for idx in range(left_idx,right_idx):
            idx = max(0, idx)        # clamp: never below 0
            idx = min(length-1, idx) # clamp: never beyond the last valid index of the array
            x = feature_array[idx]
            selected_feature.append(x)
            selected_idx.append(idx)
        
        selected_feature = np.concatenate(selected_feature, axis=0)
        #reshape the concatenated features to (-1, 384), i.e. rows of 384-dimensional vectors
        selected_feature = selected_feature.reshape(-1, 384)# 50*384
        return selected_feature,selected_idx
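
A worked example of the window (assuming fps = 25 and feature frames of shape (5, 384), as in section 3.2.3): for video frame vid_idx = 10 the center index is 20, so feature indices 16 through 25 are selected; each of the 10 frames contributes 5 rows of 384 values, giving the (50, 384) chunk noted in the code. This is roughly the [i-2, i+2] window mentioned in section 2.1.

fps, vid_idx = 25, 10
audio_feat_length = [2, 2]
center_idx = int(vid_idx * 50 / fps)                      # 20
left_idx   = center_idx - audio_feat_length[0] * 2        # 16
right_idx  = center_idx + (audio_feat_length[1] + 1) * 2  # 26 (exclusive)
print(list(range(left_idx, right_idx)))                   # [16, 17, ..., 25] -> 10 feature frames
# each feature_array[idx] has shape (5, 384), so the reshaped chunk is (10 * 5, 384) = (50, 384)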

4. Generated Results

Because the frames are generated one by one and only then assembled into a video, each generated frame can be compared with the original.
Left: ground truth; right: generated result.

(figure: per-frame comparison, left ground truth, right generated)
