Video Enhancement Implementation

The video generated by SadTalker is not very sharp: each frame is only 224x224 pixels. So I decided to add a video quality enhancement step, and a generative adversarial network is a good fit for this.

Users can choose between a standard and a high-definition digital-human video. Choosing standard trades some image quality for a faster response; choosing high-definition trades some response time for better image quality.
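In code, this user-facing choice can map directly onto the enhancer argument used below. A minimal sketch (the function name resolve_enhancer and the quality values are hypothetical, for illustration only; they are not part of SadTalker):

```python
def resolve_enhancer(quality):
    """Map the user-facing quality option to the enhancer argument.

    'standard' skips enhancement for speed; 'hd' enables GFPGAN.
    """
    return 'gfpgan' if quality == 'hd' else None

print(resolve_enhancer('standard'), resolve_enhancer('hd'))  # None gfpgan
```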

The approach:
Add a conditional in animate.py that checks whether the enhancer argument is truthy. If it is, call enhancer_generator_with_len, then merge the enhanced video with the audio and save the result locally:

if enhancer:
    video_name_enhancer = x['video_name'] + '_enhanced.mp4'
    enhanced_path = os.path.join(video_save_dir, 'temp_' + video_name_enhancer)
    av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
    return_path = av_path_enhancer

    enhanced_images_gen_with_len = enhancer_generator_with_len(
        full_video_path, method=enhancer, bg_upsampler=background_enhancer)
    imageio.mimsave(enhanced_path, enhanced_images_gen_with_len, fps=float(fps))
    save_video_with_watermark(enhanced_path, audio_path, av_path_enhancer, watermark=False)
    # print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
    # remove the temporary video-only file once the final audio+video file is saved
    os.remove(enhanced_path)

os.remove(path)

return return_path

Video quality enhancement can be treated as image quality enhancement applied to every frame of the video. The best-known GAN for face image enhancement is GFPGAN:

GFPGAN (Generative Facial Prior - Generative Adversarial Network) is a GAN-based deep learning model for face image restoration and enhancement. It can handle low-resolution, blurry, or partially damaged face images and, having learned from large amounts of face data, produces high-quality, sharp face images.

GFPGAN leverages the adversarial setup of a GAN: through competitive training, the generator learns to produce realistic face images, while the discriminator learns to distinguish generated images from real ones. This lets GFPGAN substantially improve image quality while keeping faces natural and realistic.
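The adversarial training signal described above can be sketched numerically. The toy computation below shows the standard non-saturating GAN losses; note this is only an illustration of the generator/discriminator competition, not GFPGAN's actual objective, which also adds facial-prior and perceptual terms:

```python
import math

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy: push D(real) toward 1 and D(fake) toward 0."""
    return -(math.log(d_real) + math.log(1.0 - d_fake))

def generator_loss(d_fake):
    """Non-saturating generator loss: push D(G(z)) toward 1."""
    return -math.log(d_fake)

# Early in training the discriminator easily spots fakes (D(fake) is low),
# so the generator's loss is large; as the generator improves and fools the
# discriminator half the time, both losses settle near equilibrium.
print(generator_loss(0.1))             # large: D rejects the fakes
print(generator_loss(0.5))             # smaller: G fools D half the time
print(discriminator_loss(0.5, 0.5))    # equilibrium value, 2*ln(2)
```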

On Linux, download the corresponding pretrained model weight files with wget:

mkdir -p ./gfpgan/weights
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth 
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth 
wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth 
wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth 
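The same check-then-download step can be done from Python, mirroring wget's -nc (no-clobber) behavior. A minimal sketch; the helper name ensure_weight is mine, not part of GFPGAN's API:

```python
import os
import urllib.request

def ensure_weight(path, url):
    """Download the weight file from url only if it is not already on disk."""
    if not os.path.isfile(path):
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        urllib.request.urlretrieve(url, path)
    return path

# Example: fetch GFPGANv1.4 only when missing.
# ensure_weight('./gfpgan/weights/GFPGANv1.4.pth',
#               'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth')
```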

Then install gfpgan with pip (pip install gfpgan).

def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """ Provide a generator with a __len__ method so that it can be passed to
    functions that call len()"""

    if os.path.isfile(images): # handle video to images
        # TODO: Create a generator version of load_video_to_cv2
        images = load_video_to_cv2(images)

    gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
    gen_with_len = GeneratorWithLen(gen, len(images))
    return gen_with_len
  1. The purpose of this function is to create a generator that has a __len__ method, so it can be passed to other functions that call len().
  2. It first checks whether images is a file path. If so, it calls load_video_to_cv2 to convert the video into a sequence of images.
  3. Next, it calls enhancer_generator_no_len, which returns a generator object that yields the enhanced images.
  4. Finally, it creates a GeneratorWithLen object, passing it the generator and the length of the image sequence.
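GeneratorWithLen itself is not shown above. A minimal sketch of such a wrapper, matching how it is used here (SadTalker's actual implementation may differ in detail):

```python
class GeneratorWithLen:
    """Wrap a generator so that len() works on it."""

    def __init__(self, gen, length):
        self.gen = gen
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return iter(self.gen)

# Usage: the wrapped generator still streams lazily, but now reports a length.
frames = GeneratorWithLen((i * i for i in range(3)), 3)
print(len(frames), list(frames))  # 3 [0, 1, 4]
```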

load_video_to_cv2 simply reads the video frame by frame, converts each frame from OpenCV's BGR format to RGB, collects them, and returns the list:

def load_video_to_cv2(input_path):
    video_stream = cv2.VideoCapture(input_path)
    full_frames = []
    while True:
        still_reading, frame = video_stream.read()
        if not still_reading:
            video_stream.release()
            break
        # OpenCV reads frames as BGR; convert to RGB for downstream use
        full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return full_frames
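The TODO in enhancer_generator_with_len mentions a generator version of this loader. A sketch of what that could look like (my code, not part of SadTalker) that yields frames lazily instead of holding the whole video in memory:

```python
def load_video_to_cv2_gen(input_path):
    """Yield RGB frames one at a time instead of building a full list."""
    import cv2  # generator body runs only on iteration, so the import is lazy
    video_stream = cv2.VideoCapture(input_path)
    try:
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                break
            yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    finally:
        video_stream.release()
```

The catch is that a bare generator has no len(), which is exactly why GeneratorWithLen exists; for a streamed video the frame count would have to come from cv2.CAP_PROP_FRAME_COUNT instead of len(images).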

Then comes the core function, enhancer_generator_no_len, which enhances the video frame by frame. It first checks whether gfpgan is installed and installs it if not, then sets the URL for each model so the weight file can be downloaded automatically when it is missing:

def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
    """ Provide a generator function so that all of the enhanced images don't need
    to be stored in memory at the same time. This can save tons of RAM compared to
    the enhancer function. """
    try:
        from gfpgan import GFPGANer
    except ImportError:
        print("GFPGAN library not found. Installing...")
        try:
            # Use pip to install the library
            import subprocess
            subprocess.check_call(["pip", "install", "gfpgan"])
            
            # Retry the import after installation
            from gfpgan import GFPGANer
            print("GFPGAN library installed successfully!")
        except Exception as e:
            print(f"Failed to install GFPGAN library. Error: {e}")
            # Handle the error or raise it again if needed
        
    print('face enhancer....')
    if not isinstance(images, list) and os.path.isfile(images): # handle video to images
        images = load_video_to_cv2(images)

    # ------------------------ set up GFPGAN restorer ------------------------
    if method == 'gfpgan':
        arch = 'clean'
        channel_multiplier = 2
        model_name = 'GFPGANv1.4'
        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
    elif method == 'RestoreFormer':
        arch = 'RestoreFormer'
        channel_multiplier = 2
        model_name = 'RestoreFormer'
        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
    elif method == 'codeformer': # TODO:
        arch = 'CodeFormer'
        channel_multiplier = 2
        model_name = 'CodeFormer'
        url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
    else:
        raise ValueError(f'Wrong model version {method}.')
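The elif chain above maps each method name to an architecture, channel multiplier, and weight URL. The same mapping can be expressed as a lookup table, which makes adding a new method a one-entry change. This is a refactoring sketch, not the code SadTalker ships:

```python
ENHANCER_CONFIGS = {
    'gfpgan': dict(
        arch='clean', channel_multiplier=2, model_name='GFPGANv1.4',
        url='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'),
    'RestoreFormer': dict(
        arch='RestoreFormer', channel_multiplier=2, model_name='RestoreFormer',
        url='https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'),
    'codeformer': dict(
        arch='CodeFormer', channel_multiplier=2, model_name='CodeFormer',
        url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'),
}

def get_enhancer_config(method):
    """Look up the restorer settings for a method, rejecting unknown names."""
    try:
        return ENHANCER_CONFIGS[method]
    except KeyError:
        raise ValueError(f'Wrong model version {method}.')

print(get_enhancer_config('gfpgan')['arch'])  # clean
```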

Finally comes the frame-by-frame enhancement: iterate over every frame, run it through GFPGANer's enhance method, and yield each enhanced frame:

    model_path = os.path.join('gfpgan/weights', model_name + '.pth')

    if not os.path.isfile(model_path):
        model_path = os.path.join('checkpoints', model_name + '.pth')

    if not os.path.isfile(model_path):
        # download pre-trained models from url
        model_path = url

    restorer = GFPGANer(
        model_path=model_path,
        upscale=2,
        arch=arch,
        channel_multiplier=channel_multiplier,
        bg_upsampler=bg_upsampler)

    # ------------------------ restore ------------------------
    for idx in tqdm(range(len(images)), 'Face Enhancer:'):
        # frames are stored as RGB; GFPGAN expects OpenCV's BGR ordering
        img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)

        # restore faces and background if necessary
        cropped_faces, restored_faces, r_img = restorer.enhance(
            img,
            has_aligned=False,
            only_center_face=False,
            paste_back=True)

        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
        yield r_img

This produces the quality-enhanced video, saved locally.

The last steps are packaging and deployment, which are the same as before and are not repeated here.
