GFPGAN源码分析—第九篇

2021SC@SDUSC

源码:

models\gfpgan_model.py

本篇主要分析models\gfpgan_model.py下的

class GFPGANModel(BaseModel) 类部分方法

目录

GFPGANModel(BaseModel)

setup_optimizers(self)

feed_data(self, data)



GFPGANModel(BaseModel)

setup_optimizers(self)

本函数用于设置优化器(optimizers)

train_opt = self.opt['train']

1.优化器 g的实现

net_g_reg_ratio = 1
normal_params = []
for _, param in self.net_g.named_parameters():
    normal_params.append(param)
optim_params_g = [{  # add normal params first
    'params': normal_params,
    'lr': train_opt['optim_g']['lr']
}]
optim_type = train_opt['optim_g'].pop('type')
lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
self.optimizers.append(self.optimizer_g)

2.优化器d的实现

net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
normal_params = []
for _, param in self.net_d.named_parameters():
    normal_params.append(param)
optim_params_d = [{  # add normal params first
    'params': normal_params,
    'lr': train_opt['optim_d']['lr']
}]
optim_type = train_opt['optim_d'].pop('type')
lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
self.optimizers.append(self.optimizer_d)

3.为面部组件鉴别器设置优化器

if self.use_facial_disc:
    # 为面部组件鉴别器设置优化器
    optim_type = train_opt['optim_component'].pop('type')
    lr = train_opt['optim_component']['lr']
    # left eye
    self.optimizer_d_left_eye = self.get_optimizer(
        optim_type, self.net_d_left_eye.parameters(), lr, betas=(0.9, 0.99))
    self.optimizers.append(self.optimizer_d_left_eye)
    # right eye
    self.optimizer_d_right_eye = self.get_optimizer(
        optim_type, self.net_d_right_eye.parameters(), lr, betas=(0.9, 0.99))
    self.optimizers.append(self.optimizer_d_right_eye)
    # mouth
    self.optimizer_d_mouth = self.get_optimizer(
        optim_type, self.net_d_mouth.parameters(), lr, betas=(0.9, 0.99))
    self.optimizers.append(self.optimizer_d_mouth)

feed_data(self, data)

简单看一下代码

def feed_data(self, data):
    self.lq = data['lq'].to(self.device)
    if 'gt' in data:
        self.gt = data['gt'].to(self.device)

    if 'loc_left_eye' in data:
        # 获取面部组件的位置, shape (batch, 4)
        self.loc_left_eyes = data['loc_left_eye']
        self.loc_right_eyes = data['loc_right_eye']
        self.loc_mouths = data['loc_mouth']
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值