盲超分率-元学习和KernelGAN结合-MetaKernelGAN-Meta-Learned Kernel For Blind Super-Resolution Kernel Estimation

盲超分率-元学习和KernelGAN结合-MetaKernelGAN-Meta-Learned Kernel For Blind Super-Resolution Kernel Estimation

论文链接:Meta-Learned Kernel For Blind Super-Resolution Kernel Estimation

源码链接: royson/metakernelgan

MetaKernelGAN通过结合元学习与KernelGAN的方法,实现了对模糊核的有效估计和高质量图像的恢复,从而有效解决了盲超分辨率问题。

主要涉及到一下两个重点前置知识

  1. KernelGAN:KernelGAN 通过引入一个特殊的网络结构来直接估计模糊核。在训练过程中,生成器不仅生成高分辨率图像,还利用估计的模糊核来对图像进行处理,从而提高生成图像的真实性。(可以查看上一篇文章:KernelGAN
  2. 元学习:通俗来说,就是“学习如何学习”。它的主要目标是提高机器学习模型在新任务上的学习能力和适应性,尤其是在训练样本有限的情况下。通过元学习,模型可以从之前的经验中快速学习,像人类一样更高效地解决新问题。本文使用的是元学习中的学习初始化(MAML)(可以查看小破站的一个视频,讲的很好(元学习是什么?))
元学习
元学习的简单理解

通过一下公式来做一个简单的理解,元学习分为内外层,具体训练逻辑如下:

  1. 首先外层会初始化一个w,这个作为算法参数,在外层更新,这个w会作为条件输入到内层循环中

  2. 内层用外层初始化的参数w在内层的支撑集上进行训练,得到在算法参数w的情况下最好内层参数o,用内层训练得到的参数作为外层的输入,在查询集上进行训练,如果w是好的算法参数,那么o在外层也能有好的表现,如果内层参数o不好,则说明是原来的外层参数w是不好的,则更新外层的算法参数w

  3. 更新算法参数w之后,循环上面的过程输入到内存进行内层参数更新,再更新外层。

  4. 如此循环后,w算法参数就能够得到适用于不同任务的算法参数,只需要输入一个类型任务,经过内层循环后,w就能够很快适用到该类型任务上。

在这里插入图片描述

主要的研究方向

论文使用的是学习初始化,下面介绍一些主要研究方向:学习优化(Learning to Optimize)学习初始化(Learning to Initialize)学习权重(Learning to Weight)学习奖励(Learning to Reward)学习增强(Learning to Augment)数据集蒸馏(Dataset Distillation)神经架构搜索(Neural Architecture Search)

学习优化是元学习的一个重要方向,旨在通过训练学习算法本身的优化过程。传统的优化算法(如 SGD、Adam)是手动设计的,而学习优化则试图通过元学习的方法来自动生成或改进优化算法,从而提高模型在新任务上的训练效率和效果。这一方向的研究通常涉及到设计优化器,使其能够快速适应不同的学习任务和数据分布。

本篇文章主要使用的学习初始化(Learning to Initialize),最经典的也就是下面这种模型MAML,学习初始化关注的是如何选择或学习模型参数的初始值,以便在训练新任务时能够更快地收敛。通过元学习,模型可以从多个任务的经验中获取初始化参数的有效性,从而在面对新的任务时,能够更迅速地找到适合的参数配置。这对于处理少样本学习等任务尤为重要,因为良好的初始化可以显著减少训练所需的时间和样本数量。

MAML

MAML的全称是Model-Agnostic Meta-Learning。它是一种元学习方法,旨在通过学习模型的初始化参数,使得模型能够在新任务上快速适应。

快速适应:MAML的核心思想是通过在多个任务上训练一个模型,使得这个模型在遇到新的任务时能够仅通过少量的梯度更新(通常是几次迭代)就快速适应。

模型无关:MAML的一个重要特点是“模型无关”,意味着它可以应用于各种类型的模型(如神经网络、线性模型等),并不局限于特定的架构或算法。

优化过程:在MAML中,模型在多个任务上进行内层训练(任务特定的更新),然后通过外层优化来调整模型的初始化参数。这使得模型在面对新任务时能够迅速收敛。

在这里插入图片描述

简单来说,就是算法参数和内层参数是同一个,内外层更新的都是同一个参数。

MetaKernelGAN

主要思想是使用元学习的方法来估计图像的模糊核(blur kernel)。它通过训练一个生成模糊核的模型,使得该模型能够适应多种类型的图像。当需要估计某一特定类型图像的模糊核时,只需提供少量的图像(甚至只是一张),模型就能够有效地生成适合该图像的模糊核,从而使得图像去模糊或模糊化处理效果达到最优。

MetaKernelGAN架构图

主要有内外层循环组成,内外层循环都在更新一个下采样器(downsampling),一个判别器(discriminator)

内层循环:仅计算下采样器和判别器的损失来更新下采样器和判别器。

外层循环:除了需要计算下采样器和判别器以外,还要需要计算核损失,来更新我们的模糊核。

在这里插入图片描述

从图中可以看出,对于外层参数的更新,是经过多步累计之后才对外层参数的进行更新的。下面将从源码层面剖析它的训练过程

数据集准备

在数据集制作过程中,首先从DIV2K数据集中采样高分辨率图像,将其裁剪为192×192的图像块。裁剪后的图像块会随机进行90度旋转、垂直翻转或水平翻转等数据增强操作,然后下采样生成低分辨率图像,作为用于超分辨率任务的输入数据。在元学习中同时学习生成器和判别器的参数,并将仅元学习生成器作为消融实验进行对比。

在实际制作的时候的,裁剪后的Patch是使用lmdb数据库,实际上就是生成两个数据文件,图片大小都是3x192x192,在执行之前会把这个文件的图片数据读取到dataloder,之后每个step取一张图训练。
加粗样式

外层训练

外层主要是先初始化下采样器和判别器的参数,外层的更新在内层更新之后,相当于内层更新完了参数之后,然后将更新后的参数拿到外层循环去进行训练, 来检测一下内层训练的怎么样,然后更新下采样器和判别器。不断的循环之后,最终外层会得到最好下采样器参数,同时也可以通过下采样器计算出最好的模糊核。
这里不算严格意义上的外层训练,就是简单的控制需要更新多少个任务,具体的外层训练在下面的另一张图里进行

# 这里不算严格意义上的外层训练,就是简单的控制需要更新多少个任务
def train(self):
    # 初始化训练数据迭代器,使用cycle函数使得迭代器无限循环遍历数据
    self.train_iter = iter(cycle(self.train_dataloader))
    
    # 确保训练步数 existing_step 小于总步数 steps,否则训练已经完成
    assert self.existing_step < self.steps, 'Training is done.'
    
    # 记录剩余的训练步数
    logger.info(f'{self.steps - self.existing_step} steps left.')

    # 从 existing_step 开始训练到 steps 结束
    for i in range(self.existing_step, self.steps):
        # 逐个将所有模型设为训练模式
        for model_typ, model in self.meta_models.items():
            model.train()  # 设置为训练模式
            self.optimizers[model_typ].zero_grad()  # 清空模型的梯度

        # 调用 meta_train 函数执行具体训练步骤,返回损失值
        loss = self.meta_train(i)

        # 更新所有优化器的参数
        for opt in self.optimizers.values():
            opt.step()
        
        # 更新学习率调度器
        for sch in self.schedulers.values():
            sch.step()

        # 记录训练的日志,包含当前的步数和损失
        log_msg = f'[Step {i+1}] Train error: {loss}, '            
        logger.debug(log_msg)

        # 每隔指定的步数(save_every)保存一次模型
        if (i + 1) % self.args.optim.save_every == 0:
            logger.info(log_msg)  # 输出日志信息
            self.save_ckp(i + 1, models=self.meta_models)  # 保存检查点

        # 记录当前的训练步数
        self.log_current_step(i + 1)

内层训练
  1. 首先从192x192大小的图片进行模糊化,生成两张图
    • patch_lrs_son_t(# 第二次下采样,作为生成模糊的标准图像,用来更新外层参数)—>gt_patch_son
    • patch_lrs_dad_t(# 第一次降采样,用来送入内层来更新内层参数)—>patch_dad
# 来源于下面代码 gt_patch_son, patch_dad, gt_kernel = self.sample_task()
def sample_task(self):
        '''
        具体实现时,以2X为例,是从输入图上,分别裁剪两个patch,一个给G网络,分辨率高些,代码里是64x64,这张图经过G网络之后,分辨率变为26x26,因为没有做padding,所以不是32x32,
        同时裁剪一个26x26的patch,是给D网络的,同时加入了少量的噪声,这样就构成了训练数据对。当然,裁剪patch时,并不是随机的,是通过构建了一个概率map,跟图像内容相关,
        应该是让尽量能取到边缘细节吧,G patch和D patch尽量纹理接近吧
        '''
        kernel = dutils.get_downsampling_ops(self.degradation_operation, train=True, scale=self.scale) # 将标准kernel记录下来
        logger.debug(f'Train task: kernel: {kernel.shape if kernel is not None else kernel}')
        labels = next(self.train_iter)
        # 第一次降采样,用来送入内层来更新内层参数
        patch_lrs_dad_t = fkp.batch_degradation(labels, kernel, np.array([self.scale, self.scale]), self.args.data.img_noise, device=self.device)
        print(patch_lrs_dad_t.shape)
        # 第二次下采样,作为生成模糊的标准图像,用来更新外层参数
        patch_lrs_son_t = fkp.batch_degradation(patch_lrs_dad_t, kernel, np.array([2, 2]), self.args.data.img_noise, device=self.device)
        print(patch_lrs_son_t.shape)

        if kernel is not None:
            kernel = torch.from_numpy(kernel).float()
        
        return patch_lrs_son_t, patch_lrs_dad_t, kernel
  1. 也就是上面采样的patch_lrs_dad_t,基于这张裁剪图进行内层更新,作为内层训练的输入(96x96)patch_dad,内层训练是现将patch_dad下采样得到patch_son,然后再patch_dadpatch_son在相对位置来进行随机裁剪,得到fake_d_inpatch_son得到)和 real_d_inpatch_dad得到),首先用fake_d_in 下采样器生成的图像来计算下采样生成器的loss值,然后更新下采样器的参数,然后用更新后的下采样器再对patch_dad重新进行下采样在得到patch_son,再从新的patch_son中找到原来fake_d_in的位置取出来得到g_out,然后让g_outreal_d_in来计算判别器的loss值,然后更新判别器,这样完成了一次内层更新,是基于patch_dad图像

在这里插入图片描述

def meta_train(self, step):
    # 初始化总损失和用于存储克隆后的模型字典 learners
    loss = 0
    learners = {}

    # 克隆每个 meta 模型,存入 learners 中
    for model_typ, model in self.meta_models.items():
        learners[model_typ] = model.clone()
    
    # 从任务中采样获取真实的低分辨率图像、父图像和真实核
    gt_patch_son, patch_dad, gt_kernel = self.sample_task()
    print(gt_patch_son.shape)
    print(patch_dad.shape)
    
    # 将数据传输到设备上(CPU或GPU)
    gt_patch_son, patch_dad, gt_kernel = \
            tutils.to_device(*[gt_patch_son, patch_dad, gt_kernel], cpu=self.args.sys.cpu)

    # 初始化验证过程中低分辨率子图像和判别器误差的列表
    v_lr_son_error = []
    v_dis_error = []
    
    # 计算 patch_dad 的面积
    area = np.prod(patch_dad.shape[-2:])
    
    # 通过多个适应步骤来调整模型
    for adapt_step in range(learners['downsampling'].optim_args.task_adapt_steps):
        # 使用克隆的下采样模型生成子图像 patch_son
        patch_son = learners['downsampling'](patch_dad)
        print(patch_son.shape)
        
        # 随机裁剪生成器输出的子图像和父图像,获取裁剪后的图像块和坐标
        fake_d_in, real_d_in, coords = tutils.random_image_crop(patch_son, \
                        img_dad=patch_dad, scale=self.scale, patch_size=self.dis_input_ps)
        print(fake_d_in.shape)  # 裁剪后的子图像
        print(real_d_in.shape)  # 裁剪后的父图像

        # 计算生成器损失(对抗损失)
        gen_error = self.compute_loss_generator(learners['discriminator'], fake_i=fake_d_in)
        
        # 添加正则化损失
        reg_error = self.compute_kernel_reg(learners['downsampling'])
        gen_error += reg_error
        
        # 更新下采样模块参数
        self.lr_decay(adapt_step, learners['downsampling'], area)
        learners['downsampling'].adapt(gen_error, first_order=learners['downsampling'].optim_args.first_order)

        # 更新判别器参数
        patch_son = learners['downsampling'](patch_dad)
        patch_son = patch_son.detach()
        
        # 根据生成器生成的坐标从子图像中提取图像块
        g_out = tutils.get_patch_using_coordinates(patch_son, coords, patch_size=self.dis_input_ps)
        print(g_out.shape)
        
        # 计算判别器损失(基于真实图像块和生成图像块)
        dis_error = self.compute_loss_discriminator(learners['discriminator'], real_i=real_d_in, fake_i=g_out)
        
        # 更新判别器参数
        learners['discriminator'].adapt(dis_error, first_order=learners['discriminator'].optim_args.first_order)
        
        # 在指定的步数进行验证
        if self.args.optim.validate_steps and (adapt_step + 1) in self.args.optim.validate_steps:
            logger.debug(f'Validating at step {adapt_step}')
            # 进行验证,计算当前模型的生成器和判别器误差
            g_er, d_er = self._validate_meta_train(learners, patch_dad, gt_patch_son, gt_kernel)
            v_lr_son_error.append(g_er)
            v_dis_error.append(d_er)

    # 如果没有指定验证步数,直接在最后一次计算外层误差
    if not self.args.optim.validate_steps:
        v_lr_son_error, v_dis_error = self._validate_meta_train(learners, patch_dad, gt_patch_son, gt_kernel)
    else:
        # 根据步骤获取损失权重并计算加权损失
        loss_weights = self._get_loss_weights(step).to(self.device)
        v_lr_son_error = torch.sum(loss_weights * torch.stack(v_lr_son_error).squeeze())
        if self.outer_lsgan_loss:
            v_dis_error = torch.sum(loss_weights * torch.stack(v_dis_error).squeeze())

    # 外层循环中的生成器损失计算与回传
    loss += v_lr_son_error.item()
    v_lr_son_error.backward(retain_graph=not learners['discriminator'].optim_args.first_order)
    
    # 如果不是一阶优化,则判别器的梯度归零
    if not learners['downsampling'].optim_args.first_order and 'discriminator' in learners:
        self.optimizers['discriminator'].zero_grad()

    # 外层循环中的判别器损失计算与回传
    if self.outer_lsgan_loss:
        loss += v_dis_error.item()
        self.optimizers['discriminator'].zero_grad()
        v_dis_error.backward()

    # 更新模型参数
    for model_typ, model in self.meta_models.items():
        if model.optim_args.task_opt.upper() == 'ADAM': 
            # 将克隆模型的梯度复制到原始模型中
            for model_p, clone_p in zip(model.parameters(), learners[model_typ].parameters()):
                if model_p.requires_grad:
                    model_p.grad = clone_p.grad.clone()

        # 如果任务设置了梯度裁剪,则进行梯度裁剪
        if model.optim_args.task_gradient_clip is not None and model.optim_args.task_gradient_clip > 0:
            nn.utils.clip_grad_value_(model.parameters(), model.optim_args.task_gradient_clip)
    
    return loss
  1. 内层更新完成以后,我们每隔一些步数就计算一次外层的loss值(这里使用patch_dadgt_patch_son来进行计算的,以gt_patch_son作为下采样之后的标准,通过采用下采样器采样patch_dad得到patch_son,让它和gt_patch_son来计算loss值更新参数),计算loss值的时候就不仅仅是下采样器和判别器了,还有kernel的loss值,方便后续在外层更新kernel的参数,但是不进行更新,只是记录下来,等到完成一次任务的所有内层步数之后,再对外层参数进行加权累加求loss值,更新外层参数,但是实际上由于是基于MAML的,所以外层参数还是下采样器和判别器,但是外层还需要获得一个可以适应多个任务的核,所以还需要更新kernel的参数。

在这里插入图片描述

def _validate_meta_train(self, learners, patch_dad, gt_patch_son, gt_kernel):
    # 使用 learners['downsampling'] 对输入的 patch_dad 进行下采样,得到 patch_son
    patch_son = learners['downsampling'](patch_dad)
    
    # 初始化验证过程中低分辨率图像的误差和总损失
    v_lr_son_error = 0
    loss = 0

    # 验证集上的训练,使用已经更新的下采样和判别器参数
    if self.gt_son_loss:    
        # 计算生成的低分辨率图像 patch_son 和真实低分辨率图像 gt_patch_son 之间的 L1 损失
        v_lr_son_error += self.l1_loss(patch_son, gt_patch_son)
        logger.debug(f'L1 LR_son Loss: {v_lr_son_error.item()}')

    if self.outer_lsgan_loss:
        # 随机裁剪生成器输出 patch_son,获取裁剪的图像块 d_in 和相应的坐标 coords
        d_in, _, coords = tutils.random_image_crop(patch_son, patch_size=self.dis_input_ps)
        # 计算生成器的损失(对抗损失)
        g_error = self.compute_loss_generator(learners['discriminator'], d_in)
        v_lr_son_error += g_error
        # 计算正则化误差,使用双三次插值损失作为正则化的一部分
        reg_error = self.compute_kernel_reg(learners['downsampling'], enable_bicubic_loss=True, g_input=patch_dad, g_output=patch_son)
        v_lr_son_error += reg_error
        logger.debug(f'Reg Error: {reg_error.item()}, Gen Error: {g_error.item()}')

    if self.kernel_loss:
        # 计算生成的退化核和真实核 gt_kernel 之间的误差
        k_error = self.compute_kernel_loss(learners['downsampling'], gt_kernel)
        v_lr_son_error += k_error
        logger.debug(f'Total Kernel Error: {k_error.item()}')

    logger.debug(f'Total Gen error: {v_lr_son_error.item()}')
    
    # 初始化判别器误差
    dis_error = 0
    if self.outer_lsgan_loss:
        # 使用生成器生成的图像(patch_son)和裁剪坐标,获取生成图像对应的图像块 g_out
        g_out = tutils.get_patch_using_coordinates(patch_son.detach(), coords, patch_size=self.dis_input_ps)
        # 获取真实图像 gt_patch_son 中与生成器输出对应的图像块 real_d_in
        real_d_in = tutils.get_patch_using_coordinates(gt_patch_son, coords, patch_size=self.dis_input_ps)
        # 计算判别器的损失,比较真实图像块和生成图像块
        dis_error = self.compute_loss_discriminator(learners['discriminator'], real_i=real_d_in, fake_i=g_out)
        logger.debug(f'Dis error: {dis_error.item()}')
    
    # 返回生成器误差和判别器误差
    return v_lr_son_error, dis_error

如需源码讲解可以联系我

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值