2021SC@SDUSC
源码:
models\gfpgan_model.py
本篇分析models\gfpgan_model.py下的
class GFPGANModel(BaseModel) 类的部分方法
class GFPGANModel(BaseModel)
目录
test(self)
测试
def test(self):
    """Run the generator on ``self.lq`` and store the result in ``self.output``.

    ``self.net_g_ema`` is used when the model has that attribute. Otherwise a
    warning is logged and ``self.net_g`` is used instead; it is switched to
    eval mode for the forward pass and back to training mode afterwards.
    """
    # torch.no_grad() keeps everything inside from building a computation graph.
    with torch.no_grad():
        if hasattr(self, 'net_g_ema'):
            self.net_g_ema.eval()
            # The network returns a pair; only the first element is kept.
            self.output, _ = self.net_g_ema(self.lq)
        else:
            logger = get_root_logger()
            logger.warning('Do not have self.net_g_ema, use self.net_g.')
            self.net_g.eval()
            self.output, _ = self.net_g(self.lq)
            # net_g was put in eval mode only for this call; restore training mode.
            self.net_g.train()
dist_validation()
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    """Run validation on the rank-0 process only.

    All arguments are forwarded unchanged to ``self.nondist_validation``.
    Processes whose ``self.opt['rank']`` is not 0 do nothing.
    """
    if self.opt['rank'] == 0:
        # Delegate the actual work to nondist_validation.
        self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
nondist_validation()
参数:
self, dataloader, current_iter, tb_logger, save_img
分几步看一下代码
1.进度条与with_metrics的初始化
dataset_name = dataloader.dataset.opt['name']
# Metrics are enabled only when the validation options contain a 'metrics' entry.
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics:
    # One running total per configured metric, starting from zero.
    self.metric_results = {metric: 0 for metric in self.opt['val']['metrics']}
# Progress bar sized to the number of items the dataloader yields.
pbar = tqdm(total=len(dataloader), unit='image')
2.遍历dataloader,调用feed_data,并完成图像转换与保存等
for idx, val_data in enumerate(dataloader):
    # Base name of the first low-quality image path, with the extension removed.
    img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
    # feed_data must receive the current batch; test() then runs inference on it.
    self.feed_data(val_data)
    self.test()

    visuals = self.get_current_visuals()
    # Convert the torch tensors into image arrays.
    sr_img = tensor2img([visuals['sr']], min_max=(-1, 1))
    # Reset per sample so a ground truth from an earlier sample is never reused.
    gt_img = None
    if 'gt' in visuals:
        gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
        del self.gt
    # tentative for out of GPU memory
    del self.lq
    del self.output
    torch.cuda.empty_cache()

    if save_img:
        # Choose the output path: per-image folder while training, otherwise
        # per-dataset folder tagged with the validation suffix or the run name.
        if self.opt['is_train']:
            save_img_path = osp.join(self.opt['path']['visualization'], img_name,
                                     f'{img_name}_{current_iter}.png')
        elif self.opt['val']['suffix']:
            save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                     f'{img_name}_{self.opt["val"]["suffix"]}.png')
        else:
            save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
                                     f'{img_name}_{self.opt["name"]}.png')
        imwrite(sr_img, save_img_path)

    if with_metrics:
        # Accumulate every configured metric for this sample.
        for name, opt_ in self.opt['val']['metrics'].items():
            metric_data = dict(img1=sr_img, img2=gt_img)
            self.metric_results[name] += calculate_metric(metric_data, opt_)
    # Advance the progress bar.
    pbar.update(1)
    pbar.set_description(f'Test {img_name}')
pbar.close()
3.调用_log_validation_metric_values
# with_metrics can be False; averaging and logging run only when metrics are configured.
if with_metrics:
    for metric in self.metric_results:
        # idx is the last index left by the validation loop, so idx + 1 is the
        # number of iterations that were accumulated.
        self.metric_results[metric] /= (idx + 1)
    # Report the averaged values through _log_validation_metric_values.
    self._log_validation_metric_values(current_iter, dataset_name, tb_logger)