GFPGAN源码分析—第十篇

2021SC@SDUSC

源码:

models\gfpgan_model.py

本篇继续分析init.py与models\gfpgan_model.py下的

class GFPGANModel(BaseModel) 类get_roi_regions() 方法

目录

class GFPGANModel(BaseModel)

construct_img_pyramid(self)

get_roi_regions()

 _gram_mat(self, x)



class GFPGANModel(BaseModel)

construct_img_pyramid(self)

代码:

def construct_img_pyramid(self):
    pyramid_gt = [self.gt]
    down_img = self.gt
    for _ in range(0, self.log_size - 3):
        #对down_img进行数组采样操作
        down_img = F.interpolate(down_img, scale_factor=0.5, mode='bilinear', align_corners=False)
        #将down_img插入 pyramid_gt的最前面
        pyramid_gt.insert(0, down_img)
    return pyramid_gt

重点介绍一下F.interpolate即数组采样

作用:利用插值方法,对输入的张量数组进行上/下采样操作
F.interpolate的几个参数:
1.input(Tensor):需要进行采样处理的数组。
2.size(int或序列):输出空间的大小
3.scale_factor(float或序列):空间大小的乘数
4.mode(str):用于采样的算法。'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area'。默认:'nearest'
5.align_corners(bool)
6.recompute_scale_facto(bool)

get_roi_regions()

参数:

self, eye_out_size=80, mouth_out_size=120

1.硬编码(hard cord)

rois_eyes = []
rois_mouths = []
for b in range(self.loc_left_eyes.size(0)):  # loop for batch size
    # left eye and right eye
    img_inds = self.loc_left_eyes.new_full((2, 1), b)
    #torch.stack()沿指定维度拼接
    bbox = torch.stack([self.loc_left_eyes[b, :], self.loc_right_eyes[b, :]], dim=0)  # shape: (2, 4)
    #torch.cat()对img_inds沿指定维度拼接
    rois = torch.cat([img_inds, bbox], dim=-1)  # shape: (2, 5)
    rois_eyes.append(rois)
    # mouse
    img_inds = self.loc_left_eyes.new_full((1, 1), b)
    #torch.cat()对img_inds沿指定维度拼接
    rois = torch.cat([img_inds, self.loc_mouths[b:b + 1, :]], dim=-1)  # shape: (1, 5)
    rois_mouths.append(rois)

rois_eyes = torch.cat(rois_eyes, 0).to(self.device)
rois_mouths = torch.cat(rois_mouths, 0).to(self.device)

在这里对比以下两种方法的区别

torch.cat():对tensors沿指定维度拼接,但返回的Tensor的维数不会变
torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维

3.real images

all_eyes = roi_align(self.gt, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
self.left_eyes_gt = all_eyes[0::2, :, :, :]
self.right_eyes_gt = all_eyes[1::2, :, :, :]
self.mouths_gt = roi_align(self.gt, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio

4.输出

all_eyes = roi_align(self.output, boxes=rois_eyes, output_size=eye_out_size) * face_ratio
self.left_eyes = all_eyes[0::2, :, :, :]
self.right_eyes = all_eyes[1::2, :, :, :]
self.mouths = roi_align(self.output, boxes=rois_mouths, output_size=mouth_out_size) * face_ratio

 _gram_mat(self, x)

用于计算格拉姆矩阵(Gram matrix),最后返回这样一个矩阵

参数:

x (torch.Tensor): Tensor with shape of (n, c, h, w).

代码:

n, c, h, w = x.size()
#调用view函数把原先tensor中的数据按照行优先的顺序排成一个一维的数据
features = x.view(n, c, w * h)
#交换输入张量 features 的两个维度
features_t = features.transpose(1, 2)
#计算
gram = features.bmm(features_t) / (c * h * w)
#返回
return gram

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
GFP-GAN是一种基于生成对抗网络的图像超分辨率重建方法,可以从低分辨率图像生成高分辨率图像。下面对GFP-GAN的源码进行简要分析。 GFP-GAN源码的主要组成部分包括生成器和判别器两个网络。生成器网络负责将给定的低分辨率图像作为输入,生成高分辨率图像作为输出。判别器网络则用于判断生成器生成的图像是否足够逼真。生成器和判别器网络通过对抗学习的方式进行训练,不断优化生成器的生成效果,使其生成的图像尽可能接近真实高分辨率图像。 GFP-GAN中使用了一种特殊的损失函数,包括感知损失和对抗损失。感知损失是通过计算生成图像与真实高分辨率图像之间的特征差异来衡量生成图像的质量。对抗损失则是通过判别器网络来评估生成器生成的图像是否逼真,鼓励生成器生成更真实的图像。 在源码中,可以看到生成器和判别器网络的结构定义和参数设置。还有训练过程中的数据处理部分,包括数据加载、预处理和模型训练等。此外,源码中可能还包含了一些辅助函数和工具函数,用于辅助训练和评估过程。 通过分析源码,可以深入了解GFP-GAN的具体实现细节和网络结构。同时,还可以对训练过程中的超参数设置、损失函数设计等进行调整和优化,以进一步提高GFP-GAN的生成效果和性能。 总之,通过对GFP-GAN源码分析,可以更好地理解该方法的原理和实现方式,为后续的研究和应用提供基础和参考。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值