GFPGAN源码分析—第十二篇

2021SC@SDUSC

源码:

models\gfpgan_model.py

本篇分析models\gfpgan_model.py下的

class GFPGANModel(BaseModel) 类的部分方法

class GFPGANModel(BaseModel)

目录

class GFPGANModel(BaseModel)

test(self)

dist_validation()

nondist_validation()


test(self)

测试

def test(self):
    #使用 with torch.no_grad():,强制之后的内容不进行计算图构建。
    with torch.no_grad():
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            self.output, _ = self.net_g_ema(self.lq)
        else:
            logger = get_root_logger()
            logger.warning('Do not have self.net_g_ema, use self.net_g.')
            self.net_g.eval()
            self.output, _ = self.net_g(self.lq)
            self.net_g.train()

dist_validation()

def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    if self.opt['rank'] == 0:
        #调用nondist_validation函数进行处理
        self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

nondist_validation()

参数:
self, dataloader, current_iter, tb_logger, save_img

分几步看一下代码

1.进度条与with_metrics的初始化

dataset_name = dataloader.dataset.opt['name']
#确认with_metrics is not None
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics:
    self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
    #进度条
pbar = tqdm(total=len(dataloader), unit='image')

2.遍历dataloader,做fead data以及图像变换保存等

for idx, val_data in enumerate(dataloader):
    #分离文件名与扩展名,返回一个元组。
    img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
    #调用fead_data处理val_data
    self.feed_data()
    self.test()
	#调用get_current_visuals
    visuals = self.get_current_visuals()
    #将torch张量转换为图像numpy数组
    sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
    gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))

    if 'gt' in visuals:
        gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
        del self.gt
    # tentative for out of GPU memory
    del self.lq
    del self.output
    torch.cuda.empty_cache()
#如果需要保存图片
    if save_img:
        #首先设置路径
        if self.opt['is_train']:
            save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                     f'{img_name}_{current_iter}.png')
        else:
            if self.opt['val']['suffix']:
                save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                         f'{img_name}_{self.opt["val"]["suffix"]}.png')
            else:
                save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                         f'{img_name}_{self.opt["name"]}.png')
        imwrite(sr_img, save_img_path)

    if with_metrics:
        # calculate metrics
        for name, opt_ in self.opt['val']['metrics'].items():
            metric_data = dict(img1=sr_img, img2=gt_img)
            self.metric_results[name] += calculate_metric(metric_data, opt_)
     #更新进度条
    pbar.update(1)
    pbar.set_description(f'Test {img_name}')
pbar.close()

3.调用_log_validation_metric_values

#with_metrics一定为True
if with_metrics:
    for metric in self.metric_results.keys():
        self.metric_results[metric] /= (idx + 1)
#调用_log_validation_metric_values
    self._log_validation_metric_values(current_iter, dataset_name, tb_logger)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
GFP-GAN是一种基于生成对抗网络的图像超分辨率重建方法,可以从低分辨率图像生成高分辨率图像。下面对GFP-GAN的源码进行简要分析。 GFP-GAN源码的主要组成部分包括生成器和判别器两个网络。生成器网络负责将给定的低分辨率图像作为输入,生成高分辨率图像作为输出。判别器网络则用于判断生成器生成的图像是否足够逼真。生成器和判别器网络通过对抗学习的方式进行训练,不断优化生成器的生成效果,使其生成的图像尽可能接近真实高分辨率图像。 GFP-GAN中使用了一种特殊的损失函数,包括感知损失和对抗损失。感知损失是通过计算生成图像与真实高分辨率图像之间的特征差异来衡量生成图像的质量。对抗损失则是通过判别器网络来评估生成器生成的图像是否逼真,鼓励生成器生成更真实的图像。 在源码中,可以看到生成器和判别器网络的结构定义和参数设置。还有训练过程中的数据处理部分,包括数据加载、预处理和模型训练等。此外,源码中可能还包含了一些辅助函数和工具函数,用于辅助训练和评估过程。 通过分析源码,可以深入了解GFP-GAN的具体实现细节和网络结构。同时,还可以对训练过程中的超参数设置、损失函数设计等进行调整和优化,以进一步提高GFP-GAN的生成效果和性能。 总之,通过对GFP-GAN源码分析,可以更好地理解该方法的原理和实现方式,为后续的研究和应用提供基础和参考。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值