GFPGAN源码分析—第三篇

2021SC@SDUSC

源码:utils.py

本篇主要分析utils.py中的class GFPGANer ( )的初始化以及load_file_from_url( )方法

目录

1.获取当前项目路径

2.class GFOGANer ( )——init()

(1)优先选择在cupy+gpu上运行

(2)根据参数arch选择性初始化GFP-GAN

(3)初始化face helper

(4)增加了一个model路径是网址时的处理,然而需要的model已经下载到本地,并没有用到

(5)读取model并继续初始化

3.load_file_from_url( )


1.获取当前项目路径

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

os.path.abspath(file)获取当前文件的绝对路径C:\Users\Vaifer\Desktop\GFPGAN-v.0.2.1\gfpgan\utils.py

os.path.dirname()再获取该文件所在的目录路径C:\Users\Vaifer\Desktop\GFPGAN-v.0.2.1\gfpgan

最终应该得到C:\xxx\GFPGAN-v.0.2.

2.class GFOGANer ( )——init()

参数:(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None)

(1)优先选择在cupy+gpu上运行

self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(2)根据参数arch选择性初始化GFP-GAN

if arch == 'clean':
    self.gfpgan = GFPGANv1Clean(
        out_size=512,
        num_style_feat=512,
        channel_multiplier=channel_multiplier,
        decoder_load_path=None,
        fix_decoder=False,
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=1,
        sft_half=True)
else:
    self.gfpgan = GFPGANv1(
        out_size=512,
        num_style_feat=512,
        channel_multiplier=channel_multiplier,
        decoder_load_path=None,
        fix_decoder=True,
        num_mlp=8,
        input_is_latent=True,
        different_w=True,
        narrow=1,
        sft_half=True)

可以看到分别调用了GFPGANv1Clean与GFPGANv1进行初始化,之后我们会具体分析这两个类

(3)初始化face helper

这边就使用到了facexlib包中的face restoration helper

self.face_helper = FaceRestoreHelper(
    upscale,
    face_size=512,
    crop_ratio=(1, 1),
    det_model='retinaface_resnet50',
    save_ext='png',
    device=self.device)

(4)增加了一个model路径是网址时的处理,然而需要的model已经下载到本地,并没有用到

if model_path.startswith('https://'):

(5)读取model并继续初始化

loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
    keyname = 'params_ema'
else:
    keyname = 'params'
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)

3.load_file_from_url( )

从指定url中下载文件并读取的一个函数,简单介绍下

在读取model时如果路径是网址,就会调用这个函数下载相应的model

参数:(url, model_dir=None, progress=True, file_name=None)

def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
    """
    hub_dir = get_dir()
    model_dir = os.path.join(hub_dir, 'checkpoints')
    print('hub_dir',hub_dir)
    print('model_dir',model_dir)
    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')
    #做路径的拼接,并递归创建目录
    os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
GFP-GAN是一种基于生成对抗网络的图像超分辨率重建方法,可以从低分辨率图像生成高分辨率图像。下面对GFP-GAN的源码进行简要分析。 GFP-GAN源码的主要组成部分包括生成器和判别器两个网络。生成器网络负责将给定的低分辨率图像作为输入,生成高分辨率图像作为输出。判别器网络则用于判断生成器生成的图像是否足够逼真。生成器和判别器网络通过对抗学习的方式进行训练,不断优化生成器的生成效果,使其生成的图像尽可能接近真实高分辨率图像。 GFP-GAN中使用了一种特殊的损失函数,包括感知损失和对抗损失。感知损失是通过计算生成图像与真实高分辨率图像之间的特征差异来衡量生成图像的质量。对抗损失则是通过判别器网络来评估生成器生成的图像是否逼真,鼓励生成器生成更真实的图像。 在源码中,可以看到生成器和判别器网络的结构定义和参数设置。还有训练过程中的数据处理部分,包括数据加载、预处理和模型训练等。此外,源码中可能还包含了一些辅助函数和工具函数,用于辅助训练和评估过程。 通过分析源码,可以深入了解GFP-GAN的具体实现细节和网络结构。同时,还可以对训练过程中的超参数设置、损失函数设计等进行调整和优化,以进一步提高GFP-GAN的生成效果和性能。 总之,通过对GFP-GAN源码分析,可以更好地理解该方法的原理和实现方式,为后续的研究和应用提供基础和参考。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值