2021SC@SDUSC
源码:
models\gfpgan_model.py
本篇分析models\gfpgan_model.py下的
class GFPGANModel(BaseModel) 类的最后几个方法
目录
save(self, epoch, current_iter)
class GFPGANModel(BaseModel)
_log_validation_metric_values
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
    """Report the validation metrics held in ``self.metric_results``.

    The metrics are written as one multi-line message to the root logger
    and, when a tensorboard logger is supplied, as one scalar per metric.

    Args:
        current_iter: Iteration number used as the step of each scalar.
        dataset_name: Name of the validation set, shown in the log header.
        tb_logger: Object exposing ``add_scalar(tag, value, step)``; any
            falsy value (e.g. ``None``) disables scalar logging.
    """
    metrics = self.metric_results

    # Assemble the whole message first and join once, rather than growing
    # a string with ``+=`` inside the loop.
    lines = [f'Validation {dataset_name}\n']
    lines.extend(f'\t # {metric}: {value:.4f}\n' for metric, value in metrics.items())
    logger = get_root_logger()
    logger.info(''.join(lines))

    # Also record every metric through the tensorboard logger, if given.
    if tb_logger:
        for metric, value in metrics.items():
            tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
get_current_visuals(self)
def get_current_visuals(self):
    """Return the current ground-truth and output tensors for visualization.

    Returns:
        OrderedDict: ``'gt'`` maps to ``self.gt`` and ``'sr'`` maps to
        ``self.output`` (in that insertion order), each detached from the
        autograd graph and moved to the CPU.
    """
    # OrderedDict remembers insertion order, so 'gt' always precedes 'sr'.
    out_dict = OrderedDict()
    # detach() drops the autograd history; cpu() returns a CPU tensor.
    out_dict['gt'] = self.gt.detach().cpu()
    out_dict['sr'] = self.output.detach().cpu()
    return out_dict
save(self, epoch, current_iter)
网络保存
def save(self, epoch, current_iter):
    """Save all networks and the training state for the current iteration.

    Args:
        epoch: Current epoch, forwarded to ``save_training_state``.
        current_iter: Current iteration, forwarded to every save call.
    """
    # Generator and ``net_g_ema`` (presumably its EMA copy) are saved
    # together under the label 'net_g', keyed 'params' / 'params_ema'.
    self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
    self.save_network(self.net_d, 'net_d', current_iter)

    # Component discriminators for the facial parts, only when enabled.
    if self.use_facial_disc:
        self.save_network(self.net_d_left_eye, 'net_d_left_eye', current_iter)
        self.save_network(self.net_d_right_eye, 'net_d_right_eye', current_iter)
        self.save_network(self.net_d_mouth, 'net_d_mouth', current_iter)

    # Save the training state so that training can be resumed later.
    self.save_training_state(epoch, current_iter)
介绍一下save_network与save_training_state函数的几个参数
save_network Args:
net (nn.Module | list[nn.Module]): 要保存的网络(可以是单个网络,也可以是网络列表).
net_label (str): 网络标签(Network label).
current_iter (int): Current iter number.
param_key (str | list[str]): 保存网络参数时使用的键, 默认为 'params'; 上面保存 net_g 时传入的是 ['params', 'params_ema'].
save_training_state Args:
epoch (int): Current epoch.
current_iter (int): Current iteration.